"vllm/vscode:/vscode.git/clone" did not exist on "df2503e125f3c869b0f274e64530d09bf01ea30d"
Unverified Commit ab1a6a43 authored by mikaylagawarecki's avatar mikaylagawarecki Committed by GitHub
Browse files

[3/n] Migrate cutlass/scaled_mm_entry.cu torch stable ABI (#37221)


Signed-off-by: default avatarMikayla Gawarecki <mikaylagawarecki@gmail.com>
parent b5e60825
This diff is collapsed.
...@@ -6,13 +6,15 @@ ...@@ -6,13 +6,15 @@
#include <cstdio> #include <cstdio>
#include <cstdlib> #include <cstdlib>
#include <torch/headeronly/util/shim_utils.h>
/** /**
* Helper function for checking CUTLASS errors * Helper function for checking CUTLASS errors
*/ */
#define CUTLASS_CHECK(status) \ #define CUTLASS_CHECK(status) \
{ \ { \
cutlass::Status error = status; \ cutlass::Status error = status; \
TORCH_CHECK(error == cutlass::Status::kSuccess, \ STD_TORCH_CHECK(error == cutlass::Status::kSuccess, \
cutlassGetStatusString(error)); \ cutlassGetStatusString(error)); \
} }
......
...@@ -3,6 +3,14 @@ ...@@ -3,6 +3,14 @@
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp" #include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp" #include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp"
// This header is shared by both _C (unstable ABI) and _C_stable_libtorch
// (stable ABI) targets. When compiled under the stable ABI target,
// TORCH_TARGET_VERSION is defined and Tensor is unavailable, so we
// use torch::stable::Tensor instead.
#ifdef TORCH_TARGET_VERSION
#include <torch/csrc/stable/tensor.h>
#endif
/* /*
This file defines custom epilogues for fusing channel scales, token scales, This file defines custom epilogues for fusing channel scales, token scales,
bias, and activation zero-points onto a GEMM operation using the bias, and activation zero-points onto a GEMM operation using the
...@@ -15,6 +23,12 @@ ...@@ -15,6 +23,12 @@
namespace vllm::c3x { namespace vllm::c3x {
#ifdef TORCH_TARGET_VERSION
using TensorType = torch::stable::Tensor;
#else
using TensorType = torch::Tensor;
#endif
using namespace cute; using namespace cute;
template <typename T> template <typename T>
...@@ -84,7 +98,7 @@ struct ScaledEpilogueBase { ...@@ -84,7 +98,7 @@ struct ScaledEpilogueBase {
// from a tensor. It can handle both row and column, as well as row/column or // from a tensor. It can handle both row and column, as well as row/column or
// scalar cases. // scalar cases.
template <typename Descriptor, typename T> template <typename Descriptor, typename T>
static auto args_from_tensor(torch::Tensor const& tensor) { static auto args_from_tensor(TensorType const& tensor) {
using Arguments = typename Descriptor::Arguments; using Arguments = typename Descriptor::Arguments;
auto* data_ptr = static_cast<T*>(tensor.data_ptr()); auto* data_ptr = static_cast<T*>(tensor.data_ptr());
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> || if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
...@@ -100,7 +114,7 @@ struct ScaledEpilogueBase { ...@@ -100,7 +114,7 @@ struct ScaledEpilogueBase {
// This overload handles the case where there might not be a tensor, in which // This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used. // case a nullptr is passed and a constant (0) is used.
template <typename Descriptor, typename T> template <typename Descriptor, typename T>
static auto args_from_tensor(std::optional<torch::Tensor> const& tensor) { static auto args_from_tensor(std::optional<TensorType> const& tensor) {
using Arguments = typename Descriptor::Arguments; using Arguments = typename Descriptor::Arguments;
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr; auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> || static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
...@@ -158,8 +172,8 @@ struct ScaledEpilogue ...@@ -158,8 +172,8 @@ struct ScaledEpilogue
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>; cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
using ArgumentType = typename EVTCompute::Arguments; using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales, static ArgumentType prepare_args(TensorType const& a_scales,
torch::Tensor const& b_scales) { TensorType const& b_scales) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales); auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales); auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
...@@ -203,9 +217,9 @@ struct ScaledEpilogueBias ...@@ -203,9 +217,9 @@ struct ScaledEpilogueBias
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>; cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments; using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales, static ArgumentType prepare_args(TensorType const& a_scales,
torch::Tensor const& b_scales, TensorType const& b_scales,
torch::Tensor const& bias) { TensorType const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales); auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales); auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias); auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
...@@ -246,9 +260,9 @@ struct ScaledEpilogueColumnBias ...@@ -246,9 +260,9 @@ struct ScaledEpilogueColumnBias
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>; cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments; using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales, static ArgumentType prepare_args(TensorType const& a_scales,
torch::Tensor const& b_scales, TensorType const& b_scales,
torch::Tensor const& bias) { TensorType const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales); auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales); auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias); auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
...@@ -304,10 +318,10 @@ struct ScaledEpilogueBiasAzp ...@@ -304,10 +318,10 @@ struct ScaledEpilogueBiasAzp
EVTComputeScaleB, Bias>; EVTComputeScaleB, Bias>;
using ArgumentType = typename EVTCompute::Arguments; using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales, static ArgumentType prepare_args(TensorType const& a_scales,
torch::Tensor const& b_scales, TensorType const& b_scales,
torch::Tensor const& azp_adj, TensorType const& azp_adj,
std::optional<torch::Tensor> const& bias) { std::optional<TensorType> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales); auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales); auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias); auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
...@@ -380,11 +394,11 @@ struct ScaledEpilogueBiasAzpToken ...@@ -380,11 +394,11 @@ struct ScaledEpilogueBiasAzpToken
EVTComputeScaleB, Bias>; EVTComputeScaleB, Bias>;
using ArgumentType = typename EVTCompute::Arguments; using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales, static ArgumentType prepare_args(TensorType const& a_scales,
torch::Tensor const& b_scales, TensorType const& b_scales,
torch::Tensor const& azp_adj, TensorType const& azp_adj,
torch::Tensor const& azp, TensorType const& azp,
std::optional<torch::Tensor> const& bias) { std::optional<TensorType> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales); auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales); auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias); auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
......
#pragma once #pragma once
#include <torch/csrc/stable/tensor.h>
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp" #include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp"
/* /*
...@@ -52,7 +54,7 @@ struct ScaledEpilogueBase { ...@@ -52,7 +54,7 @@ struct ScaledEpilogueBase {
// from a tensor. It can handle both row and column, as well as row/column or // from a tensor. It can handle both row and column, as well as row/column or
// scalar cases. // scalar cases.
template <typename Descriptor, typename T> template <typename Descriptor, typename T>
static auto args_from_tensor(torch::Tensor const& tensor) { static auto args_from_tensor(torch::stable::Tensor const& tensor) {
using Arguments = typename Descriptor::Arguments; using Arguments = typename Descriptor::Arguments;
auto* data_ptr = static_cast<T*>(tensor.data_ptr()); auto* data_ptr = static_cast<T*>(tensor.data_ptr());
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> || if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
...@@ -68,7 +70,8 @@ struct ScaledEpilogueBase { ...@@ -68,7 +70,8 @@ struct ScaledEpilogueBase {
// This overload handles the case where there might not be a tensor, in which // This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used. // case a nullptr is passed and a constant (0) is used.
template <typename Descriptor, typename T> template <typename Descriptor, typename T>
static auto args_from_tensor(std::optional<torch::Tensor> const& tensor) { static auto args_from_tensor(
std::optional<torch::stable::Tensor> const& tensor) {
static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>); static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
using Arguments = typename Descriptor::Arguments; using Arguments = typename Descriptor::Arguments;
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr; auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
...@@ -117,8 +120,8 @@ struct ScaledEpilogue ...@@ -117,8 +120,8 @@ struct ScaledEpilogue
cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>; cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
using ArgumentType = typename EVTCompute::Arguments; using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales, static ArgumentType prepare_args(torch::stable::Tensor const& a_scales,
torch::Tensor const& b_scales) { torch::stable::Tensor const& b_scales) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales); auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales); auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
...@@ -160,9 +163,9 @@ struct ScaledEpilogueBias ...@@ -160,9 +163,9 @@ struct ScaledEpilogueBias
using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
EVTCompute0, Bias>; EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments; using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales, static ArgumentType prepare_args(torch::stable::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::stable::Tensor const& b_scales,
torch::Tensor const& bias) { torch::stable::Tensor const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales); auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales); auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias); auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
...@@ -220,10 +223,11 @@ struct ScaledEpilogueBiasAzp ...@@ -220,10 +223,11 @@ struct ScaledEpilogueBiasAzp
using ArgumentType = typename EVTCompute::Arguments; using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales, static ArgumentType prepare_args(
torch::Tensor const& b_scales, torch::stable::Tensor const& a_scales,
torch::Tensor const& azp_adj, torch::stable::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias) { torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales); auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales); auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias); auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
...@@ -298,11 +302,11 @@ struct ScaledEpilogueBiasAzpToken ...@@ -298,11 +302,11 @@ struct ScaledEpilogueBiasAzpToken
using ArgumentType = typename EVTCompute::Arguments; using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales, static ArgumentType prepare_args(
torch::Tensor const& b_scales, torch::stable::Tensor const& a_scales,
torch::Tensor const& azp_adj, torch::stable::Tensor const& b_scales,
torch::Tensor const& azp, torch::stable::Tensor const& azp_adj, torch::stable::Tensor const& azp,
std::optional<torch::Tensor> const& bias) { std::optional<torch::stable::Tensor> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales); auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales); auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias); auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
......
...@@ -27,4 +27,61 @@ void per_token_group_quant_int8(const torch::stable::Tensor& input, ...@@ -27,4 +27,61 @@ void per_token_group_quant_int8(const torch::stable::Tensor& input,
torch::stable::Tensor& output_s, torch::stable::Tensor& output_s,
int64_t group_size, double eps, double int8_min, int64_t group_size, double eps, double int8_min,
double int8_max); double int8_max);
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
bool cutlass_group_gemm_supported(int64_t cuda_device_capability);
void cutlass_scaled_mm(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_moe_mm(torch::stable::Tensor& out_tensors,
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& expert_offsets,
torch::stable::Tensor const& problem_sizes,
torch::stable::Tensor const& a_strides,
torch::stable::Tensor const& b_strides,
torch::stable::Tensor const& c_strides, bool per_act_token,
bool per_out_ch);
void cutlass_scaled_mm_azp(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias);
void get_cutlass_moe_mm_data(
const torch::stable::Tensor& topk_ids,
torch::stable::Tensor& expert_offsets,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2,
torch::stable::Tensor& input_permutation,
torch::stable::Tensor& output_permutation, const int64_t num_experts,
const int64_t n, const int64_t k,
const std::optional<torch::stable::Tensor>& blockscale_offsets,
const bool is_gated);
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
const torch::stable::Tensor& expert_first_token_offset,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2, const int64_t n, const int64_t k,
const bool swap_ab);
void get_cutlass_batched_moe_mm_data(
torch::stable::Tensor& expert_offsets,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2,
const torch::stable::Tensor& expert_num_tokens,
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
const int64_t k);
#endif #endif
...@@ -2,9 +2,10 @@ ...@@ -2,9 +2,10 @@
// clang-format will break include orders // clang-format will break include orders
// clang-format off // clang-format off
#include <torch/all.h> #include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
#include <ATen/cuda/CUDAContext.h> #include "libtorch_stable/torch_utils.h"
#include "cutlass/cutlass.h" #include "cutlass/cutlass.h"
...@@ -25,14 +26,14 @@ ...@@ -25,14 +26,14 @@
namespace vllm::c3x { namespace vllm::c3x {
static inline cute::Shape<int, int, int, int> get_problem_shape( static inline cute::Shape<int, int, int, int> get_problem_shape(
torch::Tensor const& a, torch::Tensor const& b) { torch::stable::Tensor const& a, torch::stable::Tensor const& b) {
int32_t m = a.size(0), n = b.size(1), k = a.size(1); int32_t m = a.size(0), n = b.size(1), k = a.size(1);
return {m, n, k, 1}; return {m, n, k, 1};
} }
template <typename GemmKernel> template <typename GemmKernel>
void cutlass_gemm_caller( void cutlass_gemm_caller(
torch::Device device, cute::Shape<int, int, int, int> prob_shape, torch::stable::Device device, cute::Shape<int, int, int, int> prob_shape,
typename GemmKernel::MainloopArguments mainloop_args, typename GemmKernel::MainloopArguments mainloop_args,
typename GemmKernel::EpilogueArguments epilogue_args, typename GemmKernel::EpilogueArguments epilogue_args,
typename GemmKernel::TileSchedulerArguments scheduler = {}) { typename GemmKernel::TileSchedulerArguments scheduler = {}) {
...@@ -50,19 +51,20 @@ void cutlass_gemm_caller( ...@@ -50,19 +51,20 @@ void cutlass_gemm_caller(
CUTLASS_CHECK(gemm_op.can_implement(args)); CUTLASS_CHECK(gemm_op.can_implement(args));
size_t workspace_size = gemm_op.get_workspace_size(args); size_t workspace_size = gemm_op.get_workspace_size(args);
auto const workspace_options = auto workspace =
torch::TensorOptions().dtype(torch::kUInt8).device(device); torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
auto workspace = torch::empty(workspace_size, workspace_options); std::nullopt, device);
auto stream = at::cuda::getCurrentCUDAStream(device.index()); auto stream = get_current_cuda_stream(device.index());
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
CUTLASS_CHECK(status); CUTLASS_CHECK(status);
} }
template <typename Gemm, typename... EpilogueArgs> template <typename Gemm, typename... EpilogueArgs>
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, void cutlass_gemm_caller(torch::stable::Tensor& out,
torch::Tensor const& b, torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... epilogue_params) { EpilogueArgs&&... epilogue_params) {
using ElementAB = typename Gemm::ElementAB; using ElementAB = typename Gemm::ElementAB;
using ElementC = typename Gemm::ElementC; using ElementC = typename Gemm::ElementC;
......
...@@ -4,13 +4,12 @@ ...@@ -4,13 +4,12 @@
namespace vllm { namespace vllm {
void cutlass_scaled_mm_azp_sm90_int8(torch::Tensor& out, torch::Tensor const& a, void cutlass_scaled_mm_azp_sm90_int8(
torch::Tensor const& b, torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::Tensor const& a_scales, torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
torch::Tensor const& azp_adj, std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::Tensor> const& azp, std::optional<torch::stable::Tensor> const& bias) {
std::optional<torch::Tensor> const& bias) {
if (azp) { if (azp) {
return cutlass_scaled_mm_sm90_int8_epilogue< return cutlass_scaled_mm_sm90_int8_epilogue<
c3x::ScaledEpilogueBiasAzpToken>(out, a, b, a_scales, b_scales, azp_adj, c3x::ScaledEpilogueBiasAzpToken>(out, a, b, a_scales, b_scales, azp_adj,
......
...@@ -4,17 +4,16 @@ ...@@ -4,17 +4,16 @@
namespace vllm { namespace vllm {
void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out, void cutlass_scaled_mm_blockwise_sm100_fp8(
torch::Tensor const& a, torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::Tensor const& b, torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::Tensor const& a_scales, torch::stable::Tensor const& b_scales) {
torch::Tensor const& b_scales) { if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
if (out.dtype() == torch::kBFloat16) {
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::bfloat16_t>( cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::bfloat16_t>(
out, a, b, a_scales, b_scales); out, a, b, a_scales, b_scales);
} else { } else {
TORCH_CHECK(out.dtype() == torch::kFloat16); STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::half_t>( cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::half_t>(
out, a, b, a_scales, b_scales); out, a, b, a_scales, b_scales);
} }
......
#pragma once #pragma once
#include <torch/headeronly/util/shim_utils.h>
#include "cuda_utils.h" #include "cuda_utils.h"
#include "cutlass/cutlass.h" #include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h" #include "cutlass/numeric_types.h"
...@@ -130,10 +132,10 @@ struct cutlass_3x_gemm_fp8_blockwise { ...@@ -130,10 +132,10 @@ struct cutlass_3x_gemm_fp8_blockwise {
}; };
template <typename Gemm> template <typename Gemm>
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::Tensor const& b, torch::stable::Tensor const& b,
torch::Tensor const& a_scales, torch::stable::Tensor const& a_scales,
torch::Tensor const& b_scales) { torch::stable::Tensor const& b_scales) {
static constexpr bool swap_ab = Gemm::swap_ab; static constexpr bool swap_ab = Gemm::swap_ab;
using GemmKernel = typename Gemm::GemmKernel; using GemmKernel = typename Gemm::GemmKernel;
using StrideA = typename Gemm::GemmKernel::StrideA; using StrideA = typename Gemm::GemmKernel::StrideA;
...@@ -200,11 +202,11 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, ...@@ -200,11 +202,11 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
} }
template <typename OutType> template <typename OutType>
void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out, void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::stable::Tensor& out,
torch::Tensor const& a, torch::stable::Tensor const& a,
torch::Tensor const& b, torch::stable::Tensor const& b,
torch::Tensor const& a_scales, torch::stable::Tensor const& a_scales,
torch::Tensor const& b_scales) { torch::stable::Tensor const& b_scales) {
int32_t m = a.size(0), n = b.size(1), k = a.size(1), sms; int32_t m = a.size(0), n = b.size(1), k = a.size(1), sms;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device()); cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());
......
...@@ -4,17 +4,16 @@ ...@@ -4,17 +4,16 @@
namespace vllm { namespace vllm {
void cutlass_scaled_mm_blockwise_sm120_fp8(torch::Tensor& out, void cutlass_scaled_mm_blockwise_sm120_fp8(
torch::Tensor const& a, torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::Tensor const& b, torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::Tensor const& a_scales, torch::stable::Tensor const& b_scales) {
torch::Tensor const& b_scales) { if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
if (out.dtype() == torch::kBFloat16) {
cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::bfloat16_t>( cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::bfloat16_t>(
out, a, b, a_scales, b_scales); out, a, b, a_scales, b_scales);
} else { } else {
TORCH_CHECK(out.dtype() == torch::kFloat16); STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::half_t>( cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::half_t>(
out, a, b, a_scales, b_scales); out, a, b, a_scales, b_scales);
} }
......
#pragma once #pragma once
#include <torch/headeronly/util/shim_utils.h>
#include "cuda_utils.h" #include "cuda_utils.h"
#include "cutlass/cutlass.h" #include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h" #include "cutlass/numeric_types.h"
...@@ -138,10 +140,10 @@ struct sm120_blockwise_fp8_config_M64 { ...@@ -138,10 +140,10 @@ struct sm120_blockwise_fp8_config_M64 {
}; };
template <typename Gemm> template <typename Gemm>
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::Tensor const& b, torch::stable::Tensor const& b,
torch::Tensor const& a_scales, torch::stable::Tensor const& a_scales,
torch::Tensor const& b_scales) { torch::stable::Tensor const& b_scales) {
using GemmKernel = typename Gemm::GemmKernel; using GemmKernel = typename Gemm::GemmKernel;
using StrideA = typename Gemm::GemmKernel::StrideA; using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB; using StrideB = typename Gemm::GemmKernel::StrideB;
...@@ -196,11 +198,11 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, ...@@ -196,11 +198,11 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
} }
template <typename OutType> template <typename OutType>
void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::Tensor& out, void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::stable::Tensor& out,
torch::Tensor const& a, torch::stable::Tensor const& a,
torch::Tensor const& b, torch::stable::Tensor const& b,
torch::Tensor const& a_scales, torch::stable::Tensor const& a_scales,
torch::Tensor const& b_scales) { torch::stable::Tensor const& b_scales) {
int M = a.size(0); int M = a.size(0);
if (M <= 256) { if (M <= 256) {
using Gemm = typename sm120_blockwise_fp8_config_M64<OutType>::Gemm; using Gemm = typename sm120_blockwise_fp8_config_M64<OutType>::Gemm;
......
...@@ -5,17 +5,16 @@ ...@@ -5,17 +5,16 @@
namespace vllm { namespace vllm {
void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out, void cutlass_scaled_mm_blockwise_sm90_fp8(
torch::Tensor const& a, torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::Tensor const& b, torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::Tensor const& a_scales, torch::stable::Tensor const& b_scales) {
torch::Tensor const& b_scales) { if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
if (out.dtype() == torch::kBFloat16) {
cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::bfloat16_t>( cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::bfloat16_t>(
out, a, b, a_scales, b_scales); out, a, b, a_scales, b_scales);
} else { } else {
TORCH_CHECK(out.dtype() == torch::kFloat16); STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::half_t>( cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::half_t>(
out, a, b, a_scales, b_scales); out, a, b, a_scales, b_scales);
} }
......
#pragma once #pragma once
#include <torch/headeronly/util/shim_utils.h>
#include "cutlass/cutlass.h" #include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h" #include "cutlass/numeric_types.h"
...@@ -101,10 +103,10 @@ struct cutlass_3x_gemm_fp8_blockwise { ...@@ -101,10 +103,10 @@ struct cutlass_3x_gemm_fp8_blockwise {
}; };
template <typename Gemm> template <typename Gemm>
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::Tensor const& b, torch::stable::Tensor const& b,
torch::Tensor const& a_scales, torch::stable::Tensor const& a_scales,
torch::Tensor const& b_scales) { torch::stable::Tensor const& b_scales) {
using GemmKernel = typename Gemm::GemmKernel; using GemmKernel = typename Gemm::GemmKernel;
using StrideA = typename Gemm::GemmKernel::StrideA; using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB; using StrideB = typename Gemm::GemmKernel::StrideB;
...@@ -120,7 +122,7 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, ...@@ -120,7 +122,7 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
int32_t m = a.size(0), n = b.size(1), k = a.size(1); int32_t m = a.size(0), n = b.size(1), k = a.size(1);
TORCH_CHECK(m % 4 == 0, "m must be divisible by 4"); STD_TORCH_CHECK(m % 4 == 0, "m must be divisible by 4");
StrideA a_stride; StrideA a_stride;
StrideB b_stride; StrideB b_stride;
...@@ -161,11 +163,11 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, ...@@ -161,11 +163,11 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
} }
template <typename OutType> template <typename OutType>
void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out, void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::stable::Tensor& out,
torch::Tensor const& a, torch::stable::Tensor const& a,
torch::Tensor const& b, torch::stable::Tensor const& b,
torch::Tensor const& a_scales, torch::stable::Tensor const& a_scales,
torch::Tensor const& b_scales) { torch::stable::Tensor const& b_scales) {
// TODO: better heuristics // TODO: better heuristics
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise< cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
OutType, 1, 128, 128, Shape<_128, _128, _128>, OutType, 1, 128, 128, Shape<_128, _128, _128>,
......
#include <torch/all.h> #include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/ScalarType.h>
#include "cuda_utils.h" #include "cuda_utils.h"
#include "cutlass_extensions/common.hpp" #include "cutlass_extensions/common.hpp"
template <typename Fp8Func, typename Int8Func, typename BlockwiseFunc> template <typename Fp8Func, typename Int8Func, typename BlockwiseFunc>
void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a, void dispatch_scaled_mm(torch::stable::Tensor& c,
torch::Tensor const& b, torch::Tensor const& a_scales, torch::stable::Tensor const& a,
torch::Tensor const& b_scales, torch::stable::Tensor const& b,
std::optional<torch::Tensor> const& bias, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias,
Fp8Func fp8_func, Int8Func int8_func, Fp8Func fp8_func, Int8Func int8_func,
BlockwiseFunc blockwise_func) { BlockwiseFunc blockwise_func) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32); STD_TORCH_CHECK(a_scales.scalar_type() ==
TORCH_CHECK(b_scales.dtype() == torch::kFloat32); torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
int M = a.size(0), N = b.size(1), K = a.size(1); int M = a.size(0), N = b.size(1), K = a.size(1);
if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) && if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
(b_scales.numel() == 1 || b_scales.numel() == b.size(1))) { (b_scales.numel() == 1 || b_scales.numel() == b.size(1))) {
// Standard per-tensor/per-token/per-channel scaling // Standard per-tensor/per-token/per-channel scaling
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (a.dtype() == torch::kFloat8_e4m3fn) { if (a.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn) {
fp8_func(c, a, b, a_scales, b_scales, bias); fp8_func(c, a, b, a_scales, b_scales, bias);
} else { } else {
TORCH_CHECK(a.dtype() == torch::kInt8); STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
if constexpr (!std::is_same_v<Int8Func, std::nullptr_t>) { if constexpr (!std::is_same_v<Int8Func, std::nullptr_t>) {
int8_func(c, a, b, a_scales, b_scales, bias); int8_func(c, a, b, a_scales, b_scales, bias);
} else { } else {
int32_t version_num = get_sm_version_num(); int32_t version_num = get_sm_version_num();
TORCH_CHECK( STD_TORCH_CHECK(
false, "Int8 not supported on SM", version_num, false, "Int8 not supported on SM", version_num,
". Use FP8 quantization instead, or run on older arch (SM < 100)."); ". Use FP8 quantization instead, or run on older arch (SM < 100).");
} }
} }
} else { } else {
TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor."); STD_TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor.");
TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor."); STD_TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor.");
int32_t version_num = get_sm_version_num(); int32_t version_num = get_sm_version_num();
if (version_num >= 90) { if (version_num >= 90) {
TORCH_CHECK( STD_TORCH_CHECK(
a.size(0) == a_scales.size(0) && a.size(0) == a_scales.size(0) &&
cuda_utils::ceil_div(a.size(1), int64_t(128)) == a_scales.size(1), cuda_utils::ceil_div(a.size(1), int64_t(128)) == a_scales.size(1),
"a_scale_group_shape must be [1, 128]."); "a_scale_group_shape must be [1, 128].");
TORCH_CHECK( STD_TORCH_CHECK(
cuda_utils::ceil_div(b.size(0), int64_t(128)) == b_scales.size(0) && cuda_utils::ceil_div(b.size(0), int64_t(128)) == b_scales.size(0) &&
cuda_utils::ceil_div(b.size(1), int64_t(128)) == b_scales.size(1), cuda_utils::ceil_div(b.size(1), int64_t(128)) == b_scales.size(1),
"b_scale_group_shape must be [128, 128]."); "b_scale_group_shape must be [128, 128].");
} }
TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm"); STD_TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
blockwise_func(c, a, b, a_scales, b_scales); blockwise_func(c, a, b, a_scales, b_scales);
} }
} }
#pragma once
#include <torch/csrc/stable/tensor.h>
namespace vllm {
void cutlass_scaled_mm_sm90_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_sm90_int8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_azp_sm90_int8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_blockwise_sm90_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales);
void cutlass_scaled_mm_sm100_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_sm120_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_blockwise_sm100_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales);
void cutlass_scaled_mm_blockwise_sm120_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales);
} // namespace vllm
...@@ -3,15 +3,16 @@ ...@@ -3,15 +3,16 @@
namespace vllm { namespace vllm {
void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a, void cutlass_scaled_mm_sm100_fp8(
torch::Tensor const& b, torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::Tensor const& a_scales, torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::stable::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias) { std::optional<torch::stable::Tensor> const& bias) {
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (bias) { if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(), STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
"currently bias dtype must match output dtype ", out.dtype()); "currently bias dtype must match output dtype ",
out.scalar_type());
return cutlass_scaled_mm_sm100_fp8_epilogue<true>(out, a, b, a_scales, return cutlass_scaled_mm_sm100_fp8_epilogue<true>(out, a, b, a_scales,
b_scales, *bias); b_scales, *bias);
} else { } else {
......
#pragma once #pragma once
#include <torch/headeronly/util/shim_utils.h>
#include "scaled_mm.cuh" #include "scaled_mm.cuh"
#include "cutlass_gemm_caller.cuh" #include "cutlass_gemm_caller.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
...@@ -192,8 +194,9 @@ struct sm100_fp8_config_M16_swap_ab { ...@@ -192,8 +194,9 @@ struct sm100_fp8_config_M16_swap_ab {
}; };
template <typename Gemm, typename... EpilogueArgs> template <typename Gemm, typename... EpilogueArgs>
void cutlass_gemm_caller_sm100_fp8(torch::Tensor& out, torch::Tensor const& a, void cutlass_gemm_caller_sm100_fp8(torch::stable::Tensor& out,
torch::Tensor const& b, torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... epilogue_params) { EpilogueArgs&&... epilogue_params) {
static constexpr bool swap_ab = Gemm::swap_ab; static constexpr bool swap_ab = Gemm::swap_ab;
using ElementAB = typename Gemm::ElementAB; using ElementAB = typename Gemm::ElementAB;
...@@ -237,15 +240,15 @@ void cutlass_gemm_caller_sm100_fp8(torch::Tensor& out, torch::Tensor const& a, ...@@ -237,15 +240,15 @@ void cutlass_gemm_caller_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
template <typename InType, typename OutType, bool EnableBias, template <typename InType, typename OutType, bool EnableBias,
typename... EpilogueArgs> typename... EpilogueArgs>
inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out, inline void cutlass_gemm_sm100_fp8_dispatch(
torch::Tensor const& a, torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::Tensor const& b, torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::Tensor const& a_scales, torch::stable::Tensor const& b_scales, EpilogueArgs&&... args) {
torch::Tensor const& b_scales,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); STD_TORCH_CHECK(a.scalar_type() ==
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(b.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
using Cutlass3xGemmDefault = using Cutlass3xGemmDefault =
typename sm100_fp8_config_default<InType, OutType, typename sm100_fp8_config_default<InType, OutType,
...@@ -292,22 +295,24 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out, ...@@ -292,22 +295,24 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
} }
template <bool EnableBias, typename... EpilogueArgs> template <bool EnableBias, typename... EpilogueArgs>
void cutlass_scaled_mm_sm100_fp8_epilogue(torch::Tensor& out, void cutlass_scaled_mm_sm100_fp8_epilogue(torch::stable::Tensor& out,
torch::Tensor const& a, torch::stable::Tensor const& a,
torch::Tensor const& b, torch::stable::Tensor const& b,
torch::Tensor const& a_scales, torch::stable::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::stable::Tensor const& b_scales,
EpilogueArgs&&... epilogue_args) { EpilogueArgs&&... epilogue_args) {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); STD_TORCH_CHECK(a.scalar_type() ==
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(b.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
if (out.dtype() == torch::kBFloat16) { if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
return cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t, return cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::bfloat16_t, EnableBias>( cutlass::bfloat16_t, EnableBias>(
out, a, b, a_scales, b_scales, out, a, b, a_scales, b_scales,
std::forward<EpilogueArgs>(epilogue_args)...); std::forward<EpilogueArgs>(epilogue_args)...);
} else { } else {
TORCH_CHECK(out.dtype() == torch::kFloat16); STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
return cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t, return cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::half_t, EnableBias>( cutlass::half_t, EnableBias>(
out, a, b, a_scales, b_scales, out, a, b, a_scales, b_scales,
......
...@@ -4,15 +4,16 @@ ...@@ -4,15 +4,16 @@
namespace vllm { namespace vllm {
void cutlass_scaled_mm_sm120_fp8(torch::Tensor& out, torch::Tensor const& a, void cutlass_scaled_mm_sm120_fp8(
torch::Tensor const& b, torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::Tensor const& a_scales, torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::stable::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias) { std::optional<torch::stable::Tensor> const& bias) {
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (bias) { if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(), STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
"currently bias dtype must match output dtype ", out.dtype()); "currently bias dtype must match output dtype ",
out.scalar_type());
return cutlass_scaled_mm_sm120_fp8_epilogue<c3x::ScaledEpilogueBias>( return cutlass_scaled_mm_sm120_fp8_epilogue<c3x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias); out, a, b, a_scales, b_scales, *bias);
} else { } else {
......
#pragma once #pragma once
#include <torch/headeronly/util/shim_utils.h>
#include "scaled_mm.cuh" #include "scaled_mm.cuh"
#include "cutlass_gemm_caller.cuh" #include "cutlass_gemm_caller.cuh"
...@@ -138,13 +140,15 @@ struct sm120_fp8_config_M16 { ...@@ -138,13 +140,15 @@ struct sm120_fp8_config_M16 {
template <typename InType, typename OutType, template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue, template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs> typename... EpilogueArgs>
inline void cutlass_gemm_sm120_fp8_dispatch(torch::Tensor& out, inline void cutlass_gemm_sm120_fp8_dispatch(torch::stable::Tensor& out,
torch::Tensor const& a, torch::stable::Tensor const& a,
torch::Tensor const& b, torch::stable::Tensor const& b,
EpilogueArgs&&... args) { EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); STD_TORCH_CHECK(a.scalar_type() ==
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(b.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
int M = a.size(0); int M = a.size(0);
...@@ -177,19 +181,21 @@ inline void cutlass_gemm_sm120_fp8_dispatch(torch::Tensor& out, ...@@ -177,19 +181,21 @@ inline void cutlass_gemm_sm120_fp8_dispatch(torch::Tensor& out,
template <template <typename, typename, typename> typename Epilogue, template <template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs> typename... EpilogueArgs>
void cutlass_scaled_mm_sm120_fp8_epilogue(torch::Tensor& out, void cutlass_scaled_mm_sm120_fp8_epilogue(torch::stable::Tensor& out,
torch::Tensor const& a, torch::stable::Tensor const& a,
torch::Tensor const& b, torch::stable::Tensor const& b,
EpilogueArgs&&... epilogue_args) { EpilogueArgs&&... epilogue_args) {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); STD_TORCH_CHECK(a.scalar_type() ==
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(b.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
if (out.dtype() == torch::kBFloat16) { if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
return cutlass_gemm_sm120_fp8_dispatch<cutlass::float_e4m3_t, return cutlass_gemm_sm120_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::bfloat16_t, Epilogue>( cutlass::bfloat16_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else { } else {
TORCH_CHECK(out.dtype() == torch::kFloat16); STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
return cutlass_gemm_sm120_fp8_dispatch<cutlass::float_e4m3_t, return cutlass_gemm_sm120_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::half_t, Epilogue>( cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment