Unverified commit ad4e58bf authored by Shu Wang, committed by GitHub

Support fp8 gemm for blackwell (#4558)

parent bfb03c61
@@ -792,6 +792,282 @@ void sm90_fp8_dispatch_shape(
}
#endif
#if defined CUDA_VERSION && CUDA_VERSION >= 12080
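// SM100 (Blackwell) FP8 row-wise scaled GEMM built with the CUTLASS 3.x collective
// builders; mirrors the SM90 path above and requires CUDA 12.8+.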
template <
typename ElementType,
typename OutElementType,
typename AccumElementType,
typename CTAShape,
typename ClusterShape,
typename MainloopScheduleType,
typename EpilogueScheduleType,
typename TileSchedulerType = void,
bool WithBias = false>
struct DeviceGemmFp8RowwiseSm100 {
static_assert(std::is_same_v<ElementType, cutlass::float_e4m3_t>, "ElementType must be FP8(e4m3)");
using TileShape = CTAShape;
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
using ElementComputeEpilogue = float;
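// Epilogue visitor-tree (EVT) nodes: ScaleA broadcasts the per-row (length-M) scales of A,
// ScaleB broadcasts the per-column (length-N) scales of B, and Bias broadcasts an optional
// per-column bias vector.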
using ScaleA = cutlass::epilogue::fusion::Sm90ColBroadcast<
0,
TileShape,
ElementComputeEpilogue,
ElementComputeEpilogue,
cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;
using ScaleB = cutlass::epilogue::fusion::Sm90RowBroadcast<
0,
TileShape,
ElementComputeEpilogue,
ElementComputeEpilogue,
cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;
using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<
0,
TileShape,
OutElementType,
OutElementType,
cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;
using Compute0 = cutlass::epilogue::fusion::
Sm90Compute<cutlass::multiplies, float, float, cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
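// EVTCompute0 = scale_b * acc, computed element-wise in fp32.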
using LayoutA = cutlass::layout::RowMajor;
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementType>::value;
using LayoutB = cutlass::layout::ColumnMajor;
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementType>::value;
using ElementC = void;
using LayoutC = cutlass::layout::RowMajor;
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<OutElementType>::value;
using LayoutD = cutlass::layout::RowMajor;
static constexpr int AlignmentD = AlignmentC;
using Compute1MulAdd = cutlass::epilogue::fusion::
Sm90Compute<cutlass::multiply_add, OutElementType, float, cutlass::FloatRoundStyle::round_to_nearest>;
using Compute1Mul = cutlass::epilogue::fusion::
Sm90Compute<cutlass::multiplies, OutElementType, float, cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute = typename std::conditional_t<
WithBias,
cutlass::epilogue::fusion::Sm90EVT<Compute1MulAdd, ScaleA, EVTCompute0, Bias>,
cutlass::epilogue::fusion::Sm90EVT<Compute1Mul, ScaleA, EVTCompute0>>;
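// Full epilogue: D = scale_a * (scale_b * acc) + bias when WithBias, otherwise
// D = scale_a * (scale_b * acc), converted to OutElementType on store.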
using ArgumentType = typename EVTCompute::Arguments;
// MMA type
using ElementAccumulator = AccumElementType;
// Epilogue types
using ElementCompute = float;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm100,
cutlass::arch::OpClassTensorOp,
TileShape,
ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator,
ElementCompute,
ElementC,
LayoutC,
AlignmentC,
OutElementType,
LayoutD,
AlignmentD,
EpilogueScheduleType,
EVTCompute>::CollectiveOp;
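// Mainloop from the SM100 collective builder; the stage count is derived automatically
// after carving out shared memory for the epilogue.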
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100,
cutlass::arch::OpClassTensorOp,
ElementType,
LayoutA,
AlignmentA,
ElementType,
LayoutB,
AlignmentB,
ElementAccumulator,
TileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
MainloopScheduleType>::CollectiveOp;
using GemmKernel =
cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
template <typename Descriptor, typename T>
static auto args_from_tensor(torch::Tensor const& tensor) {
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
static_assert(
std::is_same_v<Descriptor, ScaleA> || std::is_same_v<Descriptor, ScaleB> || std::is_same_v<Descriptor, Bias>);
return Arguments{data_ptr};
}
public:
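// Builds the EVT argument tree consumed by the epilogue; scale and bias pointers are
// taken directly from the torch tensors.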
static ArgumentType prepare_args(
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias = std::nullopt) {
auto a_args = args_from_tensor<ScaleA, float>(a_scales);
auto b_args = args_from_tensor<ScaleB, float>(b_scales);
typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
if constexpr (WithBias) {
auto bias_args = args_from_tensor<Bias, OutElementType>(bias.value());
return ArgumentType{a_args, evt0_args, bias_args, {}};
} else {
return ArgumentType{a_args, evt0_args, {}};
}
}
};
template <typename GemmType, bool WithBias>
typename GemmType::Gemm::Arguments prepare_sm100_fp8_args(
torch::Tensor& out,
const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
using Gemm = typename GemmType::Gemm;
using ElementT = typename Gemm::ElementA;
using ElementC = typename Gemm::ElementC;
using ElementOutput = typename Gemm::ElementD;
using ElementComputeEpilogue = float;
using GemmKernel = typename Gemm::GemmKernel;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = StrideC;
using StrideAux = StrideC;
int32_t m = a.size(0);
int32_t n = b.size(1);
int32_t k = a.size(1);
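// A is (m, k) row-major, B is consumed as column-major with n = b.size(1), and the
// output D is (m, n) row-major.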
ElementT const* ptr_a = reinterpret_cast<ElementT const*>(a.data_ptr());
ElementT const* ptr_b = reinterpret_cast<ElementT const*>(b.data_ptr());
StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
StrideC stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, 1));
StrideAux aux_stride = stride_d;
typename GemmKernel::MainloopArguments mainloop_args{ptr_a, stride_a, ptr_b, stride_b};
typename GemmKernel::ProblemShape prob_shape = {m, n, k, 1};
cutlass::KernelHardwareInfo hw_info;
typename GemmKernel::TileSchedulerArguments scheduler = {};
auto ptr_c = static_cast<ElementOutput*>(out.data_ptr());
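// ElementC is void (no source C tensor), so the output pointer/stride is passed for
// both the C and D slots of the epilogue arguments.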
auto prepare_epilogue_args = [&](const c10::optional<torch::Tensor>& bias = c10::nullopt) {
if constexpr (WithBias) {
TORCH_CHECK(bias.has_value(), "Bias tensor is required but not provided.");
return typename GemmKernel::EpilogueArguments{
GemmType::prepare_args(scales_a, scales_b, bias.value()), ptr_c, stride_c, ptr_c, stride_d};
} else {
return typename GemmKernel::EpilogueArguments{
GemmType::prepare_args(scales_a, scales_b), ptr_c, stride_c, ptr_c, stride_d};
}
};
typename GemmKernel::Arguments args{
cutlass::gemm::GemmUniversalMode::kGemm,
prob_shape,
mainloop_args,
prepare_epilogue_args(bias),
hw_info,
scheduler};
return args;
}
template <typename Gemm, bool WithBias>
void launch_sm100_fp8_scaled_mm(
torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
auto args = prepare_sm100_fp8_args<Gemm, WithBias>(out, a, b, scales_a, scales_b, bias);
typename Gemm::Gemm gemm_op;
size_t workspace_size = gemm_op.get_workspace_size(args);
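// Allocate the CUTLASS workspace as a uint8 tensor on the same device as the inputs.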
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto workspace = torch::empty(workspace_size, workspace_options);
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
auto can_implement = gemm_op.can_implement(args);
TORCH_CHECK(can_implement == cutlass::Status::kSuccess, "cutlass cannot implement this fp8 gemm");
auto status = gemm_op.run(args, workspace.data_ptr(), stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "cutlass fp8 gemm run failed");
}
template <typename OutType>
void sm100_fp8_dispatch_bias(
torch::Tensor& out,
const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
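// Single tile/cluster configuration for now: a 256x128x64 CTA tile with a 2x2x1 cluster
// and the auto kernel/epilogue schedules for SM100.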
using CTAShape = Shape<_256, _128, _64>;
using ClusterShape = Shape<_2, _2, _1>;
using MainloopScheduleType = cutlass::gemm::collective::KernelScheduleAuto;
using EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileSchedulerType = void;
using ElementInput = cutlass::float_e4m3_t;
using ElementOutput = OutType;
using AccumElementType = float;
if (bias) {
using Gemm = DeviceGemmFp8RowwiseSm100<
ElementInput,
ElementOutput,
AccumElementType,
CTAShape,
ClusterShape,
MainloopScheduleType,
EpilogueScheduleType,
TileSchedulerType,
true>;
return launch_sm100_fp8_scaled_mm<Gemm, true>(out, a, b, scales_a, scales_b, bias);
} else {
using Gemm = DeviceGemmFp8RowwiseSm100<
ElementInput,
ElementOutput,
AccumElementType,
CTAShape,
ClusterShape,
MainloopScheduleType,
EpilogueScheduleType,
TileSchedulerType,
false>;
return launch_sm100_fp8_scaled_mm<Gemm, false>(out, a, b, scales_a, scales_b, bias);
}
}
template <typename OutType>
void sm100_fp8_dispatch_shape(
torch::Tensor& out,
const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
return sm100_fp8_dispatch_bias<OutType>(out, a, b, scales_a, scales_b, bias);
}
#endif
torch::Tensor fp8_scaled_mm(
const torch::Tensor& mat_a,
const torch::Tensor& mat_b,
@@ -833,6 +1109,17 @@ torch::Tensor fp8_scaled_mm(
auto sm_version = getSMVersion();
#if defined CUDA_VERSION && CUDA_VERSION >= 12080
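// Route SM100+ (Blackwell) devices to the new FP8 path.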
if (sm_version >= 100) {
if (out_dtype == torch::kBFloat16) {
sm100_fp8_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
} else {
sm100_fp8_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
}
return out;
}
#endif
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
if (sm_version >= 90) {
if (out_dtype == torch::kBFloat16) {
...