Unverified Commit dfe5b7df authored by Jan Bielak, committed by GitHub

[Common][PyTorch] Add support for the FP8 Block Scaling (i.e. DeepSeek) recipe on Blackwell (#2157)

* Update to_string(NVTEScalingMode) to include block scaling
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Add `nvte_swizzle_block_scaling_to_mxfp8_scaling_factors`
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Convert FP8 block scaling tensors to MXFP8 tensors on Blackwell and newer in GEMM
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Allow Blackwell and newer in the DeepSeek recipe compatibility check
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Allow data_rows % 4 != 0 in 1D kernel
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Load scaling factors in unswizzled order in 1D kernel
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Enforce use of power-of-two scaling
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Skip the FP8 block scaling exact GEMM test on Blackwell
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Skip further tests with pow_2_scales=False
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Initial implementation of tensor conversion for grouped GEMM
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Skip non-power-of-two scaling C++ unit tests
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Fix handling of all-gather
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Apply suggestions from code review
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Use compute capability 10.0 for logic with Blackwell
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Apply suggestions from code review
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

---------
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent b840898b
......@@ -501,6 +501,12 @@ TEST_P(FusedCastFloat8BlockwiseTestSuite, TestFusedCastFloat8Blockwise) {
q_opts.amax_epsilon = eps;
q_opts.block_scaling_dim = 2u;
// On Blackwell and newer, the FP8 block scaling recipe is emulated with MXFP8,
// which requires using power of two scaling factors. Skip unsupported tests.
if (getDeviceComputeCapability() >= blackwellComputeCapability && !force_pow_2) {
GTEST_SKIP();
}
if (colwise && matrix_size.size() < 2) {
// test_common Tensor initialization code does not
// handle this case.
......@@ -552,6 +558,12 @@ TEST_P(FusedCastFloat8VectorwiseTestSuite, TestFusedCastFloat8Vectorwise) {
q_opts.amax_epsilon = eps;
q_opts.block_scaling_dim = 1u;
// On Blackwell and newer, the FP8 block scaling recipe is emulated with MXFP8,
// which requires using power of two scaling factors. Skip unsupported tests.
if (getDeviceComputeCapability() >= blackwellComputeCapability && !force_pow_2) {
GTEST_SKIP();
}
if (colwise && matrix_size.size() < 2) {
// test_common Tensor initialization code does not
// handle this case.
......
......@@ -8,6 +8,7 @@ import transformer_engine as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.utils import get_device_compute_capability
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
Float8BlockQuantizer,
......@@ -19,7 +20,8 @@ from references.blockwise_fp8_gemm_reference import CuBLASRefBlockwiseGemm
def fp8_blockwise_gemm_supported() -> bool:
supported, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
return supported
emulated = get_device_compute_capability() >= (10, 0)
return supported and not emulated
def cublas_gemm_fp8_blockwise_case(
......
......@@ -12,6 +12,7 @@ import transformer_engine as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.common.recipe import Float8BlockScaling
from transformer_engine.pytorch.utils import get_device_compute_capability
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
Float8BlockQuantizer,
......@@ -32,6 +33,7 @@ tensor_dump_dir_env = os.getenv("NVTE_TEST_BLOCK_CURRENT_SCALING_EXACT_TENSOR_DU
if tensor_dump_dir_env is not None:
TENSOR_DUMP_DIR = pathlib.Path(tensor_dump_dir_env)
recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_fp8_block_scaling_available()
recipe_emulated = get_device_compute_capability() >= (10, 0)
class GetRecipes:
......@@ -218,6 +220,12 @@ def check_quantization_block_tiling_versus_reference(
pow_2_scales: bool,
tile_size: Tuple[int, int],
) -> None:
if recipe_emulated and not pow_2_scales:
pytest.skip(
"On Blackwell and newer, the FP8 block scaling recipe is emulated "
"with MXFP8, which requires using power of two scaling factors."
)
te_dtype = TE_DType[quant_dtype]
if tile_size == (1, 128):
block_scaling_dim = 1
......@@ -409,6 +417,12 @@ def test_quantization_block_tiling_extrema_versus_reference(
tile_size: Tuple[int, int],
extrema_high: bool,
) -> None:
if recipe_emulated and not pow_2_scales:
pytest.skip(
"On Blackwell and newer, the FP8 block scaling recipe is emulated "
"with MXFP8, which requires using power of two scaling factors."
)
# This test runs a single tile through a quantizer as a way to test
# branch coverage of scale computation.
te_dtype = TE_DType[quant_dtype]
......
......@@ -127,6 +127,7 @@ list(APPEND transformer_engine_SOURCES
util/multi_stream.cpp
util/rtc.cpp
swizzle/swizzle.cu
swizzle/swizzle_block_scaling.cu
fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu
fused_softmax/scaled_aligned_causal_masked_softmax.cu
......
......@@ -44,6 +44,26 @@ void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cud
void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETensor* outputs,
const size_t num_tensors, cudaStream_t stream);
/*! \brief Swizzle the scaling factors of an FP8 block scaling tensor into the MXFP8 interleaved
 *         layout for GEMM.
 *
 *  \param[in]     input   Input FP8 block scaling tensor with GEMM_READY scale_inv.
 *  \param[in,out] output  Output MXFP8 tensor which hosts the swizzled scale_inv.
 *  \param[in]     stream  CUDA stream used for the operation.
 *
 *  This function is used to emulate the FP8 block scaling recipe on Blackwell and newer, as the
 *  recipe is not natively supported by cublasLt on architectures other than Hopper.
 *  Requirements:
 *  - input is an FP8 block scaling tensor
 *  - input has rowwise usage
 *  - input.scale_inv is in GEMM_READY format
 *  - output is an MXFP8 tensor
 *  - output has rowwise usage
 *  - output.scale_inv has an appropriate shape
 */
void nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(const NVTETensor input, NVTETensor output,
cudaStream_t stream);
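/* Illustrative usage sketch (the tensor handles are hypothetical; the shapes follow the checks
 * in the implementation): for FP8 block scaling data of logical shape [M, N], `mxfp8_tensor`
 * must share the input's data pointer and own an E8M0 scale_inv buffer of shape
 * [DIVUP(M, 128) * 128, DIVUP(N, 128) * 4] before calling
 *
 *   nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(fp8_block_tensor, mxfp8_tensor, stream);
 */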
#ifdef __cplusplus
} // extern "C"
#endif
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_runtime.h>
#include <transformer_engine/swizzle.h>
#include <cstdint>
#include <type_traits>
#include "../common.h"
#include "../util/logging.h"
#include "transformer_engine/transformer_engine.h"
namespace transformer_engine {
namespace {
constexpr uint32_t WARP_SIZE = 32;
} // namespace
namespace swizzle_kernel_1d {
constexpr uint32_t WARPS_X_PER_TB = 2; // configurable
constexpr uint32_t WARPS_Y_PER_TB = 2; // configurable
// Transposes a 4x4 matrix of bytes stored across four threads with consecutive thread ids where
// each thread stores a single row (of four bytes).
// Example:
// lane0.row = 0x00010203
// lane1.row = 0x04050607
// lane2.row = 0x08090a0b
// lane3.row = 0x0c0d0e0f
// Becomes:
// lane0.row = 0x0004080c
// lane1.row = 0x0105090d
// lane2.row = 0x02060a0e
// lane3.row = 0x03070b0f
uint32_t __device__ __forceinline__ transpose_4x4_byte_matrix(const uint32_t row,
const uint32_t lane,
const uint32_t active_mask) {
using cu = const uint32_t;
// Threads operate in groups of 4, and each thread stores 4 bytes at a time.
// The bytes in this 4x4 matrix are labeled in hex. We shuffle around bytes
// until we have transposed the 4x4 matrix.
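// Intrinsics used below, for reference: __shfl_xor_sync(mask, v, laneMask, width) exchanges v
// with the lane whose id differs in the bits of laneMask (width 4 keeps the exchange within each
// group of four lanes), and __byte_perm(a, b, s) assembles a word from the 8-byte pool
// {a: bytes 0-3, b: bytes 4-7}, with each selector nibble of s (from the LSB) picking one byte.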
cu m_0123_4567_89ab_cdef = row;
cu m_4567_0123_cdef_89ab = __shfl_xor_sync(active_mask, m_0123_4567_89ab_cdef, 1, 4);
cu m_0426_4062_8cae_c8ea = __byte_perm(m_0123_4567_89ab_cdef, m_4567_0123_cdef_89ab, 0x6240);
cu m_5173_1537_d9fb_9dbf = __byte_perm(m_0123_4567_89ab_cdef, m_4567_0123_cdef_89ab, 0x3715);
cu m_0426_1537_8cae_9dbf = (lane & 1) ? m_5173_1537_d9fb_9dbf : m_0426_4062_8cae_c8ea;
cu m_8cae_9dbf_0426_1537 = __shfl_xor_sync(active_mask, m_0426_1537_8cae_9dbf, 2, 4);
cu m_048c_159d_8c04_9d15 = __byte_perm(m_0426_1537_8cae_9dbf, m_8cae_9dbf_0426_1537, 0x5410);
cu m_ae26_bf37_26ae_37bf = __byte_perm(m_0426_1537_8cae_9dbf, m_8cae_9dbf_0426_1537, 0x3276);
cu m_048c_159d_26ae_37bf = (lane & 2) ? m_ae26_bf37_26ae_37bf : m_048c_159d_8c04_9d15;
return m_048c_159d_26ae_37bf;
}
// Expands a uint32_t to a uint4 by duplicating each byte four times.
// Example: 0x01020304u becomes uint4{0x01010101, 0x02020202, 0x03030303, 0x04040404}
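// (A selector whose four nibbles all equal k makes __byte_perm replicate byte k of x.)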
uint4 __device__ __forceinline__ broadcast_uint32_t_to_uint4(uint32_t x) {
return {__byte_perm(x, 0, 0x0000), __byte_perm(x, 0, 0x1111), __byte_perm(x, 0, 0x2222),
__byte_perm(x, 0, 0x3333)};
}
// Tag struct denoting whether the number of rows of the input fp8 block scaling tensor's data
// matrix is divisible by 128. If it is not, some threads could read out of bounds scaling factors.
struct no_oob_tag_t {};
constexpr no_oob_tag_t NO_OOB_TAG;
template <typename OOBT>
void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE)
swizzle_block_scaling_1d_to_mxfp8_scaling_factors_kernel(
const void* __restrict__ const in, void* __restrict__ const out, const uint32_t tiles_x,
const uint32_t tiles_y, const uint32_t in_y_stride, const uint32_t out_y_stride,
OOBT first_oob) {
// resolve kernel variant
constexpr bool no_oob = std::is_same_v<OOBT, no_oob_tag_t>;
static_assert(no_oob || std::is_same_v<OOBT, uint32_t>);
// load thread indices
const uint32_t lane = threadIdx.x;
__builtin_assume(lane < WARP_SIZE);
const uint32_t warp_x = threadIdx.z;
__builtin_assume(warp_x < WARPS_X_PER_TB);
const uint32_t warp_y = threadIdx.y;
__builtin_assume(warp_y < WARPS_Y_PER_TB);
// compute tile indices
const uint32_t out_tile_y = blockIdx.y * WARPS_Y_PER_TB + warp_y;
const uint32_t out_tile_x = blockIdx.x * WARPS_X_PER_TB + warp_x;
const uint32_t in_tile_y = out_tile_x;
const uint32_t in_tile_x = out_tile_y;
// bounds check; uniform branch
if (out_tile_y >= tiles_y || out_tile_x >= tiles_x) {
return;
}
// calculate this warp's input base pointer
constexpr uint32_t in_x_stride = WARP_SIZE * sizeof(uint4);
const void* const warp_src = in + in_tile_y * in_y_stride + in_tile_x * in_x_stride;
// load scaling factors for this lane's initial four 1x128 tiles
uint4 sf;
if constexpr (no_oob) {
sf = reinterpret_cast<const uint4*>(warp_src)[lane];
} else {
if ((out_tile_y < tiles_y - 1) || lane < first_oob) {
sf = reinterpret_cast<const uint4*>(warp_src)[lane];
} else {
sf = uint4{0, 0, 0, 0};
}
}
// pack the exponent bits of the scaling factors
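// Each FP32 scale is a power of two (enforced at quantization time), so its value is fully
// described by exponent bits [30:23], which already match the E8M0 encoding. The shifts below
// place the exponent byte of sf.x/y/z/w into bytes 0/1/2/3 of packed_exponents; the sign and
// mantissa bits that also get OR'd in are all zero for positive powers of two.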
uint32_t packed_exponents = (sf.x >> 23) | (sf.y >> 15) | (sf.z >> 7) | (sf.w << 1);
// partially swizzle the scaling factors
constexpr uint32_t ACTIVE_MASK = 0xFFFFFFFF; // no divergent branches
const uint32_t lane_load_idx = (lane % 4) * 8 + (lane / 4);
packed_exponents = __shfl_sync(ACTIVE_MASK, packed_exponents, lane_load_idx);
// transpose 4x4 matrices of scaling factors
packed_exponents = transpose_4x4_byte_matrix(packed_exponents, lane % 4, ACTIVE_MASK);
// broadcast the scaling factors for sixteen 1x32 tiles
sf = broadcast_uint32_t_to_uint4(packed_exponents);
// store them cooperatively for 512 1x32 tiles in a 128x128 tile
constexpr uint32_t out_x_stride = 512;
void* const warp_dst = out + out_tile_y * out_y_stride + out_tile_x * out_x_stride;
reinterpret_cast<uint4*>(warp_dst)[lane] = sf;
}
void launch_kernel(const void* const in, void* const out, uint32_t data_rows, uint32_t data_cols,
cudaStream_t stream) {
NVTE_CHECK(is_aligned_ptr(in, alignof(uint4)), "Input scaling factor pointer must be aligned to ",
alignof(uint4), " bytes");
NVTE_CHECK(is_aligned_ptr(out, alignof(uint4)),
"Output scaling factor pointer must be aligned to ", alignof(uint4), " bytes");
NVTE_CHECK(data_rows % 4 == 0, "Input tensor must not have any padding scaling factors");
const uint32_t tiles_x = DIVUP(data_cols, 128u);
const uint32_t tiles_y = DIVUP(data_rows, 128u);
const dim3 grid_dim{DIVUP(tiles_x, WARPS_X_PER_TB), DIVUP(tiles_y, WARPS_Y_PER_TB), 1};
const dim3 block_dim{WARP_SIZE, WARPS_Y_PER_TB, WARPS_X_PER_TB};
// Each 128x128 tile in the data corresponds to a 128x1 tile in the input scales
// and a 128x4 tile in the output scales. The input scales are in transposed order.
const uint32_t input_scale_inv_cols = DIVUP(data_rows, 4u) * 4;
const uint32_t output_scale_inv_cols = tiles_x * 128 * 4;
const uint32_t in_y_stride = input_scale_inv_cols * sizeof(float);
const uint32_t out_y_stride = output_scale_inv_cols * sizeof(uint8_t);
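// Input scale rows are padded to a multiple of 4 floats rather than 128, while each warp
// consumes 128 floats per tile (32 lanes, one uint4 each). first_oob is therefore the index of
// the first lane in the last 128-wide tile whose uint4 load would run past the end of a row.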
const uint32_t first_oob = (input_scale_inv_cols % 128) / 4;
if (first_oob == 0) {
swizzle_block_scaling_1d_to_mxfp8_scaling_factors_kernel<<<grid_dim, block_dim, 0, stream>>>(
in, out, tiles_x, tiles_y, in_y_stride, out_y_stride, NO_OOB_TAG);
} else {
swizzle_block_scaling_1d_to_mxfp8_scaling_factors_kernel<<<grid_dim, block_dim, 0, stream>>>(
in, out, tiles_x, tiles_y, in_y_stride, out_y_stride, first_oob);
}
}
} // namespace swizzle_kernel_1d
namespace swizzle_kernel_2d {
constexpr uint32_t WARPS_X_PER_TB = 2; // configurable
constexpr uint32_t WARPS_Y_PER_TB = 2; // configurable
void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE)
swizzle_block_scaling_2d_to_mxfp8_scaling_factors_kernel(
const void* __restrict__ const in, void* __restrict__ const out, const uint32_t tiles_x,
const uint32_t tiles_y, const uint32_t in_y_stride, const uint32_t out_y_stride) {
// load thread indices
const uint32_t lane = threadIdx.x;
__builtin_assume(lane < WARP_SIZE);
const uint32_t warp_x = threadIdx.z;
__builtin_assume(warp_x < WARPS_X_PER_TB);
const uint32_t warp_y = threadIdx.y;
__builtin_assume(warp_y < WARPS_Y_PER_TB);
// compute tile indices
const uint32_t out_tile_y = blockIdx.y * WARPS_Y_PER_TB + warp_y;
const uint32_t out_tile_x = blockIdx.x * WARPS_X_PER_TB + warp_x;
const uint32_t in_tile_y = out_tile_y;
const uint32_t in_tile_x = out_tile_x;
// bounds check; uniform branch
if (out_tile_y >= tiles_y || out_tile_x >= tiles_x) {
return;
}
// calculate this warp's input base pointer
constexpr uint32_t in_x_stride = sizeof(float);
const void* const warp_src = in + in_tile_y * in_y_stride + in_tile_x * in_x_stride;
// load scaling factor for this warp's 128x128 tile
uint32_t sf = *reinterpret_cast<const uint32_t*>(warp_src);
// broadcast it to four scaling factors for 1x32 tiles
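// The FP32 scale is a power of two, so only exponent bits [30:23] matter and they already match
// E8M0: (sf << 1) | (sf >> 7) places that exponent byte into bytes 3 and 2, and OR-ing with
// sf >> 16 then copies it into bytes 1 and 0, leaving the E8M0 value replicated in every byte.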
sf = (sf << 1) | (sf >> 7);
sf = sf | (sf >> 16);
// broadcast it to sixteen scaling factors for 1x32 tiles
const uint4 sf4{sf, sf, sf, sf};
// store it cooperatively for 512 1x32 tiles in a 128x128 tile
constexpr uint32_t out_x_stride = 512;
void* const warp_dst = out + out_tile_y * out_y_stride + out_tile_x * out_x_stride;
reinterpret_cast<uint4*>(warp_dst)[lane] = sf4;
}
void launch_kernel(const void* const in, void* const out, uint32_t data_rows, uint32_t data_cols,
cudaStream_t stream) {
NVTE_CHECK(is_aligned_ptr(in, alignof(float)), "Input scaling factor pointer must be aligned to ",
alignof(float), " bytes");
NVTE_CHECK(is_aligned_ptr(out, alignof(uint4)),
"Output scaling factor pointer must be aligned to ", alignof(uint4), " bytes");
const uint32_t tiles_x = DIVUP(data_cols, 128u);
const uint32_t tiles_y = DIVUP(data_rows, 128u);
const dim3 grid_dim{DIVUP(tiles_x, WARPS_X_PER_TB), DIVUP(tiles_y, WARPS_Y_PER_TB), 1};
const dim3 block_dim{WARP_SIZE, WARPS_Y_PER_TB, WARPS_X_PER_TB};
// Each 128x128 tile in the data corresponds to a 1x1 tile in the input scales
// and a 128x4 tile in the output scales.
const uint32_t input_scale_inv_cols = DIVUP(data_cols, 512u) * 4;
const uint32_t output_scale_inv_cols = tiles_x * 128 * 4;
const uint32_t in_y_stride = input_scale_inv_cols * sizeof(float);
const uint32_t out_y_stride = output_scale_inv_cols * sizeof(uint8_t);
swizzle_block_scaling_2d_to_mxfp8_scaling_factors_kernel<<<grid_dim, block_dim, 0, stream>>>(
in, out, tiles_x, tiles_y, in_y_stride, out_y_stride);
}
} // namespace swizzle_kernel_2d
void swizzle_block_scaling_to_mxfp8_scaling_factors(const Tensor* input, Tensor* output,
cudaStream_t stream) {
// Do nothing if tensor is empty
if (input->data.numel() == 0) {
return;
}
CheckInputTensor(*input, "block_scaling_scaling_factor_input");
CheckInputTensor(*output, "mxfp8_scaling_factor_output");
const NVTEScalingMode scaling_mode = input->scaling_mode;
NVTE_CHECK(scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D,
"Input tensor must be a block scaling tensor");
NVTE_CHECK(output->scaling_mode == NVTE_MXFP8_1D_SCALING,
"Output tensor must be an mxfp8 tensor");
NVTE_CHECK(input->data.dtype == transformer_engine::DType::kFloat8E4M3 ||
input->data.dtype == transformer_engine::DType::kFloat8E5M2,
"Input data must have FP8E4M3 or FP8E5M2 dtype to be compatible with MXFP8");
NVTE_CHECK(output->data.dtype == input->data.dtype,
"Output data must have the same dtype as input data");
NVTE_CHECK(input->scale_inv.dtype == DType::kFloat32, "Input must have FP32 scaling factors");
NVTE_CHECK(output->scale_inv.dtype == DType::kFloat8E8M0,
"Output must have E8M0 scaling factors");
NVTE_CHECK(input->data.dptr != nullptr, "Input must have rowwise data");
NVTE_CHECK(output->data.dptr == input->data.dptr, "Output must share data with input");
NVTE_CHECK(input->scale_inv.dptr != nullptr, "Input must have rowwise scaling factors");
NVTE_CHECK(output->scale_inv.dptr != nullptr, "Output must have rowwise scaling factors");
NVTE_CHECK(input->data.shape.size() == 2, "Input data must be a matrix");
NVTE_CHECK(output->data.shape == input->data.shape,
"Output data must have the same shape as input data");
NVTE_CHECK(input->scale_inv.shape.size() == 2, "Input scaling factors must be a matrix");
NVTE_CHECK(output->scale_inv.shape.size() == 2, "Output scaling factors must be a matrix");
const size_t data_rows = input->data.shape[0];
const size_t data_cols = input->data.shape[1];
const size_t input_scale_inv_rows = input->scale_inv.shape[0];
const size_t input_scale_inv_cols = input->scale_inv.shape[1];
const size_t output_scale_inv_rows = output->scale_inv.shape[0];
const size_t output_scale_inv_cols = output->scale_inv.shape[1];
NVTE_CHECK(output_scale_inv_rows == DIVUP<size_t>(data_rows, 128) * 128,
"Expected the output scaling factor matrix to have ",
DIVUP<size_t>(data_rows, 128) * 128, " rows, but it has ", output_scale_inv_rows,
" rows instead.");
NVTE_CHECK(output_scale_inv_cols == DIVUP<size_t>(data_cols, 128) * 4,
"Expected the output scaling factor matrix to have ",
DIVUP<size_t>(data_cols, 128) * 4, " columns, but it has ", output_scale_inv_cols,
" columns instead.");
if (scaling_mode == NVTE_BLOCK_SCALING_1D) {
NVTE_CHECK(input_scale_inv_rows == DIVUP<size_t>(data_cols, 128),
"Expected the input scaling factor matrix to have ", DIVUP<size_t>(data_cols, 128),
" rows, but it has ", input_scale_inv_rows, " rows instead.");
NVTE_CHECK(input_scale_inv_cols == DIVUP<size_t>(data_rows, 4) * 4,
"Expected the input scaling factor matrix to have ", DIVUP<size_t>(data_rows, 4) * 4,
" columns, but it has ", input_scale_inv_cols, " columns instead.");
swizzle_kernel_1d::launch_kernel(input->scale_inv.dptr, output->scale_inv.dptr, data_rows,
data_cols, stream);
} else { // scaling_mode == NVTE_BLOCK_SCALING_2D
NVTE_CHECK(input_scale_inv_rows == DIVUP<size_t>(data_rows, 128),
"Expected the input scaling factor matrix to have ", DIVUP<size_t>(data_rows, 128),
" rows, but it has ", input_scale_inv_rows, " rows instead.");
NVTE_CHECK(input_scale_inv_cols == DIVUP<size_t>(data_cols, 512) * 4,
"Expected the input scaling factor matrix to have ",
DIVUP<size_t>(data_cols, 512) * 4, " columns, but it has ", input_scale_inv_cols,
" columns instead.");
swizzle_kernel_2d::launch_kernel(input->scale_inv.dptr, output->scale_inv.dptr, data_rows,
data_cols, stream);
}
}
} // namespace transformer_engine
void nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_swizzle_block_scaling_to_mxfp8_scaling_factors);
using namespace transformer_engine;
swizzle_block_scaling_to_mxfp8_scaling_factors(convertNVTETensorCheck(input),
convertNVTETensorCheck(output), stream);
}
......@@ -64,6 +64,10 @@ std::string to_string(const NVTEScalingMode &mode) {
return "NVTE_DELAYED_TENSOR_SCALING";
case NVTE_MXFP8_1D_SCALING:
return "NVTE_MXFP8_1D_SCALING";
case NVTE_BLOCK_SCALING_1D:
return "NVTE_BLOCK_SCALING_1D";
case NVTE_BLOCK_SCALING_2D:
return "NVTE_BLOCK_SCALING_2D";
case NVTE_NVFP4_1D_SCALING:
return "NVTE_NVFP4_1D_SCALING";
case NVTE_INVALID_SCALING:
......
......@@ -14,6 +14,7 @@
#include "common/common.h"
#include "common/recipe/recipe_common.cuh"
#include "common/util/cuda_runtime.h"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
......@@ -485,6 +486,11 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
NVTE_API_CALL(quantize_transpose_square_blockwise);
checkCuDriverContext(stream);
if (transformer_engine::cuda::sm_arch() >= 100) {
NVTE_CHECK(pow_2_scale, "On Blackwell and newer, the FP8 block scaling recipe is emulated ",
"with MXFP8, which requires using power of two scaling factors.");
}
NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape.");
const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u;
size_t num_rows = 1;
......
......@@ -17,6 +17,7 @@
#include "common/common.h"
#include "common/recipe/recipe_common.cuh"
#include "common/transpose/cast_transpose.h"
#include "common/util/cuda_runtime.h"
#include "common/utils.cuh"
namespace transformer_engine {
......@@ -529,6 +530,11 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
cudaStream_t stream) {
NVTE_API_CALL(quantize_transpose_vector_blockwise);
if (transformer_engine::cuda::sm_arch() >= 100) {
NVTE_CHECK(pow2_scale, "On Blackwell and newer, the FP8 block scaling recipe is emulated ",
"with MXFP8, which requires using power of two scaling factors.");
}
const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u;
size_t num_elements = row_length;
size_t num_rows = 1;
......
......@@ -104,6 +104,10 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
const bool low_precision =
detail::is_low_precision(A_tensor.dtype()) || detail::is_low_precision(B_tensor.dtype());
const bool fp8_block_scaling = A_tensor.scaling_mode() == NVTE_BLOCK_SCALING_1D ||
A_tensor.scaling_mode() == NVTE_BLOCK_SCALING_2D ||
B_tensor.scaling_mode() == NVTE_BLOCK_SCALING_1D ||
B_tensor.scaling_mode() == NVTE_BLOCK_SCALING_2D;
// Check tensor dimensions
const auto& A_shape = A_tensor.shape();
......@@ -235,6 +239,19 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
swizzled_scale_inverses_list.emplace_back(
std::move(swizzle_scaling_factors(B_tensor, !transb)));
// Emulate the FP8 block scaling recipe with MXFP8 on Blackwell and newer
// as it is not natively supported by cublasLt
if (fp8_block_scaling && transformer_engine::cuda::sm_arch() >= 100) {
// Convert tensors to mxfp8 and swizzle their scaling factors
swizzled_scale_inverses_list.emplace_back(
std::move(convert_block_scaling_to_mxfp8_tensor(A_tensor, transa)));
swizzled_scale_inverses_list.emplace_back(
std::move(convert_block_scaling_to_mxfp8_tensor(B_tensor, !transb)));
// Use TN GEMM to avoid having to transpose data.
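// (The conversion picks the rowwise copy when an operand already had the layout a TN GEMM
// expects, and otherwise reinterprets the columnwise copy as rowwise data, so at this point
// both operands look like TN inputs.)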
transa = true;
transb = false;
}
if (comm_overlap) {
// Prepare extra output tensor
TensorWrapper extra_output_tensor;
......@@ -379,15 +396,6 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
std::vector<at::Tensor> bias, DType bias_type, bool single_output,
std::vector<at::Tensor> pre_gelu_out, bool grad, std::vector<at::Tensor> workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count) {
std::vector<NVTETensor> te_A_vector, te_B_vector, te_D_vector, te_bias_vector,
te_pre_gelu_out_vector, te_workspace_vector;
std::vector<TensorWrapper> te_A_wrappers, te_B_wrappers, wrappers;
std::vector<at::Tensor> D_vectors;
auto none = py::none();
std::vector<size_t> single_output_begins;
std::vector<size_t> single_output_ends;
if (single_output && D == std::nullopt) {
NVTE_ERROR("not implemented, D should be allocated for single output case.");
}
......@@ -397,6 +405,10 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
output_data_ptr = (*D)[0].data_ptr();
}
const auto none = py::none();
std::vector<TensorWrapper> te_A_wrappers, te_B_wrappers, te_D_wrappers, te_bias_wrappers,
te_pre_gelu_out_wrappers;
std::vector<at::Tensor> D_vectors;
for (size_t i = 0; i < A.size(); i++) {
auto te_A = makeTransformerEngineTensor(A[i], none);
auto te_B = makeTransformerEngineTensor(B[i], none);
......@@ -462,29 +474,72 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
te_pre_gelu_out =
makeTransformerEngineTensor(get_data_ptr(pre_gelu_out[i]), gelu_shape, gelu_type);
te_A_vector.emplace_back(te_A.data());
te_B_vector.emplace_back(te_B.data());
te_D_vector.emplace_back(te_D.data());
te_bias_vector.emplace_back(te_bias.data());
te_pre_gelu_out_vector.emplace_back(te_pre_gelu_out.data());
te_A_wrappers.emplace_back(std::move(te_A));
te_B_wrappers.emplace_back(std::move(te_B));
wrappers.emplace_back(std::move(te_D));
wrappers.emplace_back(std::move(te_bias));
wrappers.emplace_back(std::move(te_pre_gelu_out));
te_D_wrappers.emplace_back(std::move(te_D));
te_bias_wrappers.emplace_back(std::move(te_bias));
te_pre_gelu_out_wrappers.emplace_back(std::move(te_pre_gelu_out));
}
// Keep the swizzled scaling factor tensors alive during the GEMM.
std::vector<std::optional<at::Tensor>> swizzled_scale_inverses_list;
// Optionally swizzle the scaling factors
// Keep the swizzled scaling factor tensors alive during the GEMMs.
auto swizzled_scale_inv_A = multi_tensor_swizzle_scaling_factors(te_A_wrappers, transa);
auto swizzled_scale_inv_B = multi_tensor_swizzle_scaling_factors(te_B_wrappers, !transb);
swizzled_scale_inverses_list.emplace_back(
multi_tensor_swizzle_scaling_factors(te_A_wrappers, transa));
swizzled_scale_inverses_list.emplace_back(
multi_tensor_swizzle_scaling_factors(te_B_wrappers, !transb));
// Emulate the FP8 block scaling recipe with MXFP8 on Blackwell and newer
// as it is not natively supported by cublasLt
if (transformer_engine::cuda::sm_arch() >= 100) {
// Check whether FP8 block scaling is in use
bool exists_tensor_using_fp8_block_scaling = false;
bool exists_tensor_not_using_fp8_block_scaling = false;
for (const auto& tensor_wrappers : {&te_A_wrappers, &te_B_wrappers}) {
for (const TensorWrapper& tensor : *tensor_wrappers) {
const NVTEScalingMode scaling_mode = tensor.scaling_mode();
if (scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D)
exists_tensor_using_fp8_block_scaling = true;
else
exists_tensor_not_using_fp8_block_scaling = true;
}
}
if (exists_tensor_using_fp8_block_scaling) {
NVTE_CHECK(!exists_tensor_not_using_fp8_block_scaling,
"Either all tensors or no tensor must be FP8 block scaling tensors");
// Convert tensors to mxfp8 and swizzle their scaling factors
for (TensorWrapper& A_tensor : te_A_wrappers) {
swizzled_scale_inverses_list.emplace_back(
convert_block_scaling_to_mxfp8_tensor(A_tensor, transa));
}
for (TensorWrapper& B_tensor : te_B_wrappers) {
swizzled_scale_inverses_list.emplace_back(
convert_block_scaling_to_mxfp8_tensor(B_tensor, !transb));
}
// Use TN GEMM to avoid having to transpose data.
transa = true;
transb = false;
}
}
std::vector<NVTETensor> te_A_vector, te_B_vector, te_D_vector, te_bias_vector,
te_pre_gelu_out_vector;
for (size_t i = 0; i < te_A_wrappers.size(); i++) {
te_A_vector.emplace_back(te_A_wrappers[i].data());
te_B_vector.emplace_back(te_B_wrappers[i].data());
te_D_vector.emplace_back(te_D_wrappers[i].data());
te_bias_vector.emplace_back(te_bias_wrappers[i].data());
te_pre_gelu_out_vector.emplace_back(te_pre_gelu_out_wrappers[i].data());
}
std::vector<NVTETensor> te_workspace_vector;
std::vector<TensorWrapper> te_workspace_wrappers;
for (size_t i = 0; i < workspace.size(); i++) {
auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(),
std::vector<size_t>{workspaceSize}, DType::kByte);
te_workspace_vector.emplace_back(wsp.data());
wrappers.emplace_back(std::move(wsp));
te_workspace_wrappers.emplace_back(std::move(wsp));
}
// For now, we only have multi-stream cublas backend.
......
......@@ -7,6 +7,7 @@
#include "util.h"
#include "common.h"
#include "common/common.h"
std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrapper& input,
bool rowwise) {
......@@ -177,3 +178,72 @@ std::optional<at::Tensor> multi_tensor_swizzle_scaling_factors(
return buffer;
}
at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapper& input,
bool rowwise) {
using namespace transformer_engine::pytorch;
using transformer_engine::DIVUP;
// Check input tensor
const NVTEScalingMode scaling_mode = input.scaling_mode();
NVTE_CHECK(scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D,
"Input tensor must be a block scaling tensor");
// Get tensor data
NVTEBasicTensor data;
size_t data_flat_first_dim = 1;
size_t data_flat_last_dim = 1;
if (rowwise) {
data = input.get_rowwise_data();
for (int i = 0; i < data.shape.ndim - 1; ++i) {
data_flat_first_dim *= data.shape.data[i];
}
data_flat_last_dim = data.shape.data[data.shape.ndim - 1];
} else {
data = input.get_columnwise_data();
data_flat_first_dim = data.shape.data[0];
for (int i = 1; i < data.shape.ndim; ++i) {
data_flat_last_dim *= data.shape.data[i];
}
}
NVTEShape data_shape{};
data_shape.data[0] = data_flat_first_dim;
data_shape.data[1] = data_flat_last_dim;
data_shape.ndim = 2;
// Recreate input tensor with rowwise usage
transformer_engine::TensorWrapper input_cu(scaling_mode);
input_cu.set_rowwise_data(data.data_ptr, input.dtype(), data_shape);
const NVTEBasicTensor scale_inv =
rowwise ? input.get_rowwise_scale_inv() : input.get_columnwise_scale_inv();
input_cu.set_rowwise_scale_inv(
scale_inv.data_ptr, static_cast<transformer_engine::DType>(scale_inv.dtype), scale_inv.shape);
// Create output tensor
transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING);
output_cu.set_rowwise_data(data.data_ptr, input.dtype(), data_shape);
// Output swizzled mxfp8 scaling factor dimensions
const size_t swizzled_scale_inv_first_dim = DIVUP<size_t>(data_flat_first_dim, 128) * 128;
const size_t swizzled_scale_inv_last_dim = DIVUP<size_t>(data_flat_last_dim, 128) * 4;
// Allocate memory for swizzled mxfp8 scaling factors
const auto options = at::TensorOptions().dtype(torch::kByte).device(torch::kCUDA);
at::Tensor swizzled_scale_inv = at::empty(
std::vector<int64_t>{swizzled_scale_inv_first_dim, swizzled_scale_inv_last_dim}, options);
// Set rowwise scaling factors on output
void* const swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0);
NVTEShape swizzled_scale_inv_shape{};
swizzled_scale_inv_shape.data[0] = swizzled_scale_inv_first_dim;
swizzled_scale_inv_shape.data[1] = swizzled_scale_inv_last_dim;
swizzled_scale_inv_shape.ndim = 2;
output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
swizzled_scale_inv_shape);
// Convert scaling factors from FP8 block scaling GEMM_READY format to mxfp8 swizzled format
nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(input_cu.data(), output_cu.data(),
at::cuda::getCurrentCUDAStream());
// Set the input tensor to be the converted mxfp8 tensor and return the swizzled scaling factor
// for it to be kept alive during the GEMM
input = std::move(output_cu);
return swizzled_scale_inv;
}
......@@ -27,4 +27,16 @@ std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrap
std::optional<at::Tensor> multi_tensor_swizzle_scaling_factors(
std::vector<transformer_engine::TensorWrapper> &inputs, bool rowwise);
/*! \brief Convert a block scaling tensor to an mxfp8 tensor in-place.
*
* If rowwise==false, the columnwise data will be reinterpreted as rowwise data to avoid
* transposing it in memory. Due to differences in how block scaling and mxfp8 store data,
* this requires the calling code to treat the output tensor as having been transposed in this case.
*
* Returns the swizzled scaling factor of the converted mxfp8 tensor.
* The returned swizzled scaling factor tensor should be kept alive during the GEMM.
*/
at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapper &input,
bool rowwise);
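/* Sketch of the intended call pattern, mirroring its use in the GEMM bindings (variable names
 * here are hypothetical):
 *
 *   std::vector<std::optional<at::Tensor>> keep_alive;
 *   keep_alive.emplace_back(convert_block_scaling_to_mxfp8_tensor(A_tensor, transa));
 *   keep_alive.emplace_back(convert_block_scaling_to_mxfp8_tensor(B_tensor, !transb));
 *   transa = true;   // both operands are now laid out for a TN GEMM
 *   transb = false;
 */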
#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_
......@@ -1015,12 +1015,8 @@ def _post_process_fp8_blockwise_gather(
if out._is_gemm_ready_format():
return out
needs_columnwise_data_transpose = (
quantizer is not None and quantizer.columnwise_usage and not is_non_tn_fp8_gemm_supported()
)
need_rowwise_scale_transpose = (
quantizer is not None and quantizer.rowwise_usage and not is_non_tn_fp8_gemm_supported()
)
needs_columnwise_data_transpose = quantizer is not None and quantizer.columnwise_usage
need_rowwise_scale_transpose = quantizer is not None and quantizer.rowwise_usage
# CuBLAS requires transpose of the scale inv tensor, suppose orig input is 256x1024
# columnwise compact format means doing 128x1 quantization of it
......
......@@ -64,13 +64,12 @@ def check_nvfp4_support() -> Tuple[bool, str]:
def check_fp8_block_scaling_support() -> Tuple[bool, str]:
"""Return if fp8 block scaling support is available"""
if (
get_device_compute_capability() >= (9, 0)
and get_device_compute_capability() < (10, 0)
and float(torch.version.cuda) >= 12.9
):
if get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.9:
return True, ""
return False, "FP8 block scaled GEMM requires Hopper and CUDA >= 12.9."
return (
False,
"FP8 block scaled GEMM requires compute capability 9.0 or higher and CUDA >= 12.9.",
)
def check_recipe_support(recipe: Recipe) -> None:
......