"git@developer.sourcefind.cn:gaoqiong/pybind11.git" did not exist on "a6c6fd8dfe7d7ec092e0179e3203f3da9fb05fa5"
Commit b944277c authored by wenjh

[Blockwise] Add support for block_len=64



Add an environment variable to choose the block length for blockwise quantization.
Signed-off-by: wenjh <wenjh@sugon.com>

Fix blockwise pytest failures
Signed-off-by: wenjh <wenjh@sugon.com>

Adapt the int8 GEMM test to the new API
Signed-off-by: wenjh <wenjh@sugon.com>

Fix incorrect launch parameter
Signed-off-by: wenjh <wenjh@sugon.com>

Fix 1D blockwise (block_len=64) accuracy error
Signed-off-by: wenjh <wenjh@sugon.com>
parent 251dcc7e
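The new knob is read once at module load, so it must be set before transformer_engine is imported. A minimal usage sketch (the value 64 is chosen for illustration; the default remains 128):

import os

# fp8.py reads NVTE_BLOCKWISE_FP8_BLOCK_LEN once at import time.
os.environ["NVTE_BLOCKWISE_FP8_BLOCK_LEN"] = "64"

from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len

assert blockwise_fp8_block_len == 64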
......@@ -25,7 +25,11 @@ struct QuantizationOptions {
size_t block_scaling_dim = 2u;
};
#ifdef __HIP_PLATFORM_AMD__
size_t kBlockLen = static_cast<size_t>(blockwise_fp8_block_len());
#else
constexpr size_t kBlockLen = 128;
#endif
enum ProcessingMethod {
CAST_ONLY,
......@@ -80,8 +84,13 @@ template <typename InputType, typename OutputType>
void ref_quantize(const ProcessingMethod processing_method, const InputType* input,
const std::pair<size_t, size_t>& input_hw, OutputType* output, float* scale_inv,
OutputType* output_t, float* scale_inv_t, const QuantizationOptions& opts) {
#ifdef __HIP_PLATFORM_AMD__
size_t kBlockLenX = kBlockLen;
size_t kBlockLenY = kBlockLen;
#else
constexpr size_t kBlockLenX = kBlockLen;
constexpr size_t kBlockLenY = kBlockLen;
#endif
auto quantize_element = [](InputType element, float qscale) -> OutputType {
// Scale in FP32 and cast result to nearest FP8.
......@@ -157,7 +166,11 @@ void ref_quantize_onedimensional_blocks(const ProcessingMethod processing_method
float input_type_max_val = Quantized_Limits<InputType>::max();
float quant_type_max_val = Quantized_Limits<OutputType>::max();
#ifdef __HIP_PLATFORM_AMD__
size_t kBlockLenX = kBlockLen;
#else
constexpr size_t kBlockLenX = kBlockLen;
#endif
auto quantize_element = [](InputType element, float qscale) -> OutputType {
// Scale in FP32 and cast result to nearest FP8.
......
......@@ -168,13 +168,13 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
scale_inv_meta ret_rowwise, ret_colwise;
{
auto scale_dim_0 = DIVUP(first_dim, static_cast<size_t>(128));
auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast<size_t>(128)), 4) * 4;
auto scale_dim_0 = DIVUP(first_dim, static_cast<size_t>(blockwise_fp8_block_len()));
auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast<size_t>(blockwise_fp8_block_len())), 4) * 4;
ret_rowwise.shape = {scale_dim_0, scale_dim_1};
}
{
auto scale_dim_0 = DIVUP(last_dim, static_cast<size_t>(128));
auto scale_dim_1 = DIVUP(DIVUP(first_dim, static_cast<size_t>(128)), 4) * 4;
auto scale_dim_0 = DIVUP(last_dim, static_cast<size_t>(blockwise_fp8_block_len()));
auto scale_dim_1 = DIVUP(DIVUP(first_dim, static_cast<size_t>(blockwise_fp8_block_len())), 4) * 4;
ret_colwise.shape = {scale_dim_0, scale_dim_1};
}
ret_rowwise.type = DType::kFloat32;
......@@ -194,12 +194,12 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
scale_inv_meta ret_rowwise, ret_colwise;
{
auto scale_dim_0 = DIVUP(last_dim, static_cast<size_t>(128));
auto scale_dim_0 = DIVUP(last_dim, static_cast<size_t>(blockwise_fp8_block_len()));
auto scale_dim_1 = DIVUP(first_dim, 4) * 4;
ret_rowwise.shape = {scale_dim_0, scale_dim_1};
}
{
auto scale_dim_0 = DIVUP(first_dim, static_cast<size_t>(128));
auto scale_dim_0 = DIVUP(first_dim, static_cast<size_t>(blockwise_fp8_block_len()));
auto scale_dim_1 = DIVUP(last_dim, 4) * 4;
ret_colwise.shape = {scale_dim_0, scale_dim_1};
}
......
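For reference, a minimal sketch of the rowwise scale-shape arithmetic above, assuming DIVUP is ceil-division and noting that the inner dimension is padded up to a multiple of 4, as in the code:

def divup(a: int, b: int) -> int:
    return (a + b - 1) // b

def rowwise_scale_shape(first_dim: int, last_dim: int, block_len: int):
    return (divup(first_dim, block_len),
            divup(divup(last_dim, block_len), 4) * 4)

# With block_len = 64, a (256, 320) tensor needs a (4, 8) scale grid:
# divup(256, 64) = 4 rows; divup(320, 64) = 5 -> padded up to 8 columns.
assert rowwise_scale_shape(256, 320, 64) == (4, 8)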
......@@ -22,6 +22,18 @@
namespace test {
using namespace transformer_engine;
inline int blockwise_fp8_block_len() {
const char *env = std::getenv("NVTE_BLOCKWISE_FP8_BLOCK_LEN");
if (env == nullptr || env[0] == '\0') {
return 128;
}
int value;
std::istringstream iss(env);
iss >> value;
NVTE_CHECK(iss, "Invalid environment variable value");
return value;
}
template <size_t i>
struct BytesToType {};
......
......@@ -8,6 +8,7 @@ import torch
import triton
import triton.language as tl
from torch.utils.cpp_extension import IS_HIP_EXTENSION
from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
@triton.jit
......@@ -135,7 +136,7 @@ class CuBLASRefBlockwiseGemm:
N, K_w = qw.shape
assert K == K_w, "K dimension mismatch between qx and qw"
tile_len = 128
tile_len = blockwise_fp8_block_len
# Calculate grid sizes without padding
grid_m = (M + tile_len - 1) // tile_len
grid_n = (N + tile_len - 1) // tile_len
......
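The grid sizing above is plain ceil-division on the env-selected tile length; a quick sanity check with assumed example sizes:

tile_len = 64
M, N = 300, 128
grid_m = (M + tile_len - 1) // tile_len  # 5 tiles along M
grid_n = (N + tile_len - 1) // tile_len  # 2 tiles along N
assert (grid_m, grid_n) == (5, 2)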
......@@ -7,7 +7,7 @@ import math
import torch
from typing import Optional, Protocol, Tuple
from references.quantize_scale_calc import scale_from_amax_tensor
from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
@dataclasses.dataclass()
class QuantizeResult:
......@@ -277,7 +277,7 @@ class BlockwiseQuantizerReference:
return_transpose: bool = False,
eps: float = 0.0,
pow_2_scales: bool = False,
quant_tile_shape: Tuple[int, int] = (128, 128),
quant_tile_shape: Tuple[int, int] = (blockwise_fp8_block_len, blockwise_fp8_block_len),
) -> QuantizeResult:
# sanity checks
assert x.dim() == 2
......@@ -293,7 +293,7 @@ class BlockwiseQuantizerReference:
torch.int8,
), "Unsupported quant dtype."
assert quant_tile_shape in ((1, 128), (128, 128))
assert quant_tile_shape in ((1, blockwise_fp8_block_len), (blockwise_fp8_block_len, blockwise_fp8_block_len))
if quant_tile_shape[0] == 1:
# Quantize row-wise
return self.scale_munger.munge_scale_shapes_for_backend(
......
......@@ -8,7 +8,7 @@ import transformer_engine as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.fp8 import (FP8GlobalStateManager, blockwise_fp8_block_len)
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
Float8BlockQuantizer,
Float8BlockwiseQTensor,
......@@ -77,8 +77,9 @@ def cublas_gemm_fp8_blockwise_case(
assert not (use_bias and use_grad), "Bias grad not supported by GEMM"
# Set quantize_op and quantization parameters
x_quant_tile_shape = (1, 128) if is_x_1d_scaled else (128, 128)
w_quant_tile_shape = (1, 128) if is_w_1d_scaled else (128, 128)
block_len = blockwise_fp8_block_len
x_quant_tile_shape = (1, block_len) if is_x_1d_scaled else (block_len, block_len)
w_quant_tile_shape = (1, block_len) if is_w_1d_scaled else (block_len, block_len)
x_block_scaling_dim = 1 if is_x_1d_scaled else 2
w_block_scaling_dim = 1 if is_w_1d_scaled else 2
x_te_dtype = TE_DType[x_dtype]
......@@ -247,8 +248,9 @@ def cublas_gemm_test_constraint_enforced(
out = None
# Set quantize_op and quantization parameters
x_quant_tile_shape = (1, 128) if is_x_1d_scaled else (128, 128)
w_quant_tile_shape = (1, 128) if is_w_1d_scaled else (128, 128)
block_len = blockwise_fp8_block_len
x_quant_tile_shape = (1, block_len) if is_x_1d_scaled else (block_len, block_len)
w_quant_tile_shape = (1, block_len) if is_w_1d_scaled else (block_len, block_len)
x_block_scaling_dim = 1 if is_x_1d_scaled else 2
w_block_scaling_dim = 1 if is_w_1d_scaled else 2
x_te_dtype = TE_DType[x_dtype]
......
......@@ -10,7 +10,7 @@ import pytest
import torch
import transformer_engine as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.fp8 import (FP8GlobalStateManager, blockwise_fp8_block_len)
from transformer_engine.common.recipe import Float8BlockScaling
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
......@@ -99,9 +99,9 @@ def check_quantization_block_tiling_versus_reference(
tile_size: Tuple[int, int],
) -> None:
te_dtype = TE_DType[quant_dtype]
if tile_size == (1, 128):
if tile_size in ((1, 128), (1, 64)):
block_scaling_dim = 1
elif tile_size == (128, 128):
elif tile_size in ((128, 128), (64, 64)):
block_scaling_dim = 2
else:
raise ValueError("Unsupported tile size")
......@@ -214,7 +214,7 @@ def check_quantization_block_tiling_versus_reference(
"return_transpose", [True, False], ids=["quantize_transpose", "quantize_only"]
)
@pytest.mark.parametrize("pow_2_scales", [True], ids=["pow2scales"])
@pytest.mark.parametrize("tile_size", [(1, 128), (128, 128)], ids=["1DTile", "2DTile"])
@pytest.mark.parametrize("tile_size", [(1, 128), (128, 128), (1, 64), (64, 64)], ids=["1D128Tile", "2D128Tile", "1D64Tile", "2D64Tile"])
def test_quantization_block_tiling_versus_reference(
x_dtype: torch.dtype,
M: int,
......@@ -225,6 +225,8 @@ def test_quantization_block_tiling_versus_reference(
pow_2_scales: bool,
tile_size: Tuple[int, int],
) -> None:
if blockwise_fp8_block_len != tile_size[1]:
pytest.skip("Tile size does not match NVTE_BLOCKWISE_FP8_BLOCK_LEN.")
check_quantization_block_tiling_versus_reference(
x_dtype, M, N, quant_dtype, eps, return_transpose, pow_2_scales, tile_size
)
......@@ -249,7 +251,7 @@ def test_quantization_block_tiling_versus_reference(
"return_transpose", [True, False], ids=["quantize_transpose", "quantize_only"]
)
@pytest.mark.parametrize("pow_2_scales", [False], ids=["fp32scales"])
@pytest.mark.parametrize("tile_size", [(1, 128), (128, 128)], ids=["1DTile", "2DTile"])
@pytest.mark.parametrize("tile_size", [(1, 128), (128, 128), (1, 64), (64, 64)], ids=["1D128Tile", "2D128Tile", "1D64Tile", "2D64Tile"])
def test_quantization_block_tiling_versus_reference_fp32_scales(
x_dtype: torch.dtype,
M: int,
......@@ -260,6 +262,8 @@ def test_quantization_block_tiling_versus_reference_fp32_scales(
pow_2_scales: bool,
tile_size: Tuple[int, int],
) -> None:
if blockwise_fp8_block_len != tile_size[1]:
pytest.skip("Tile size does not match NVTE_BLOCKWISE_FP8_BLOCK_LEN.")
check_quantization_block_tiling_versus_reference(
x_dtype, M, N, quant_dtype, eps, return_transpose, pow_2_scales, tile_size
)
......@@ -277,7 +281,7 @@ def test_quantization_block_tiling_versus_reference_fp32_scales(
@pytest.mark.parametrize("quant_dtype", [torch.int8, torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("eps", [0], ids=["eps_0"])
@pytest.mark.parametrize("pow_2_scales", [True, False], ids=["pow2scales", "fp32scales"])
@pytest.mark.parametrize("tile_size", [(128, 128)])
@pytest.mark.parametrize("tile_size", [(128, 128), (64, 64)], ids=["2D128Tile", "2D64Tile"])
@pytest.mark.parametrize("extrema_high", [False, True], ids=["zeros", "maxes"])
def test_quantization_block_tiling_extrema_versus_reference(
x_dtype: torch.dtype,
......@@ -291,10 +295,12 @@ def test_quantization_block_tiling_extrema_versus_reference(
) -> None:
# This test runs a single tile through a quantizer as a way to test
# branch coverage of scale computation.
if blockwise_fp8_block_len != tile_size[1]:
pytest.skip("Tile size does not match NVTE_BLOCKWISE_FP8_BLOCK_LEN.")
te_dtype = TE_DType[quant_dtype]
if tile_size == (1, 128):
if tile_size in ((1, 128), (1, 64)):
block_scaling_dim = 1
elif tile_size == (128, 128):
elif tile_size in ((128, 128), (64, 64)):
block_scaling_dim = 2
else:
raise ValueError("Unsupported tile size")
......
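With the env-gated skips above, each test run exercises only the tile family matching the configured block length, so a full sweep takes two invocations, e.g. once with NVTE_BLOCKWISE_FP8_BLOCK_LEN=128 and once with NVTE_BLOCKWISE_FP8_BLOCK_LEN=64. A hypothetical sketch of the shared guard:

import os
import pytest

def skip_unless_matching(tile_size):
    # Mirrors the guard used in the tests above.
    block_len = int(os.getenv("NVTE_BLOCKWISE_FP8_BLOCK_LEN", "128"))
    if block_len != tile_size[1]:
        pytest.skip("Tile size does not match NVTE_BLOCKWISE_FP8_BLOCK_LEN.")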
......@@ -4,7 +4,7 @@ import transformer_engine as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.fp8 import (FP8GlobalStateManager, blockwise_fp8_block_len)
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
Float8BlockQuantizer,
Float8BlockwiseQTensor,
......@@ -82,8 +82,9 @@ def cublas_gemm_fp8_blockwise_case_fw(
assert not (use_bias and use_grad), "Bias grad not supported by GEMM"
# Set quantize_op and quantization parameters
x_quant_tile_shape = (1, 128) if is_x_1d_scaled else (128, 128)
w_quant_tile_shape = (1, 128) if is_w_1d_scaled else (128, 128)
block_len = blockwise_fp8_block_len
x_quant_tile_shape = (1, block_len) if is_x_1d_scaled else (block_len, block_len)
w_quant_tile_shape = (1, block_len) if is_w_1d_scaled else (block_len, block_len)
x_block_scaling_dim = 1 if is_x_1d_scaled else 2
w_block_scaling_dim = 1 if is_w_1d_scaled else 2
x_te_dtype = TE_DType[x_dtype]
......@@ -196,7 +197,7 @@ def cublas_gemm_fp8_blockwise_case_fw(
ref_scales_w = qw._columnwise_scale_inv if w_columnwise else qw._rowwise_scale_inv
y, _ = w8a8_block_int8_matmul(
qx_data, qw_data, ref_scales_x, ref_scales_w, [128, 128],
qx_data, qw_data, ref_scales_x, ref_scales_w, [block_len, block_len],
output_dtype=out_dtype
)
......@@ -265,8 +266,9 @@ def cublas_gemm_fp8_blockwise_case_bw_xgrad(
assert not (use_bias and use_grad), "Bias grad not supported by GEMM"
# Set quantize_op and quantization parameters
dout_quant_tile_shape = (1, 128) if is_dout_1d_scaled else (128, 128)
w_quant_tile_shape = (1, 128) if is_w_1d_scaled else (128, 128)
block_len = blockwise_fp8_block_len
dout_quant_tile_shape = (1, block_len) if is_dout_1d_scaled else (block_len, block_len)
w_quant_tile_shape = (1, block_len) if is_w_1d_scaled else (block_len, block_len)
dout_block_scaling_dim = 1 if is_dout_1d_scaled else 2
w_block_scaling_dim = 1 if is_w_1d_scaled else 2
dout_te_dtype = TE_DType[dout_dtype]
......@@ -373,7 +375,7 @@ def cublas_gemm_fp8_blockwise_case_bw_xgrad(
ref_scales_w = qw._columnwise_scale_inv if w_columnwise else qw._rowwise_scale_inv
y, _ = w8a8_block_int8_matmul(
qdout_data, qw_data, ref_scales_dout, ref_scales_w, [128, 128],
qdout_data, qw_data, ref_scales_dout, ref_scales_w, [block_len, block_len],
output_dtype=dx_dtype
)
......@@ -441,8 +443,9 @@ def cublas_gemm_fp8_blockwise_case_bw_wgrad(
assert not (use_bias and use_grad), "Bias grad not supported by GEMM"
# Set quantize_op and quantization parameters
dout_quant_tile_shape = (1, 128) if is_dout_1d_scaled else (128, 128)
x_quant_tile_shape = (1, 128) if is_x_1d_scaled else (128, 128)
block_len = blockwise_fp8_block_len
dout_quant_tile_shape = (1, block_len) if is_dout_1d_scaled else (block_len, block_len)
x_quant_tile_shape = (1, block_len) if is_x_1d_scaled else (block_len, block_len)
dout_block_scaling_dim = 1 if is_dout_1d_scaled else 2
x_block_scaling_dim = 1 if is_x_1d_scaled else 2
dout_te_dtype = TE_DType[dout_dtype]
......@@ -552,7 +555,8 @@ def cublas_gemm_fp8_blockwise_case_bw_wgrad(
# print(f"ref_scales_dout.shape: {ref_scales_dout.shape}, ref_scales_x.shape: {ref_scales_x.shape}")
y, _ = w8a8_block_int8_matmul_wgrad(
qdout_data, qx_data, ref_scales_dout, ref_scales_x, [128, 128],
qdout_data, qx_data, ref_scales_dout, ref_scales_x, dw.clone() if accumulate else None,
accumulate, [block_len, block_len],
output_dtype=dw_dtype
)
......
......@@ -6,6 +6,7 @@
#ifndef TRANSFORMER_ENGINE_COMMON_COMMON_H_
#define TRANSFORMER_ENGINE_COMMON_COMMON_H_
#include "util/system.h"
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h>
#endif
......@@ -33,6 +34,10 @@ namespace transformer_engine {
std::string to_string(const DType type);
std::string to_string(const NVTEScalingMode &mode);
inline int blockwise_fp8_block_len() {
return ::transformer_engine::getenv<int>("NVTE_BLOCKWISE_FP8_BLOCK_LEN", 128);
}
inline bool is_tensor_scaling(const NVTEScalingMode &mode) {
return mode == NVTE_DELAYED_TENSOR_SCALING;
}
......
......@@ -14,6 +14,7 @@
namespace transformer_engine {
namespace fp8_block_scaling_recipe {
constexpr int kTileDim64 = 64;
constexpr int kTileDim = 128;
constexpr int kThreadsPerBlock = 256;
......@@ -116,10 +117,10 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
if (h_in_input < h && w_in_input < w && idx_in_input >= start_offset &&
idx_in_input < end_offset) {
float inp = static_cast<float>(input_minus_offset[idx_in_input]) * scale;
if constexpr(std::is_same_v<OType, int8_t>) {
smem[h_in_smem][w_in_smem] = static_cast<OType>(lroundf(fmaxf(-127.0f, fminf(127.0f, inp))));
}
else {
if constexpr (std::is_same_v<OType, int8_t>) {
smem[h_in_smem][w_in_smem] =
static_cast<OType>(lroundf(fmaxf(-127.0f, fminf(127.0f, inp))));
} else {
smem[h_in_smem][w_in_smem] = static_cast<OType>(inp);
}
skip_store = false;
......@@ -175,11 +176,171 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
}
}
template <typename IType>
__global__ void __launch_bounds__(kThreadsPerBlock)
fp8_block_scaling_block_len64_compute_partial_amax_kernel(
const IType *input, float *amax_ptr, const size_t amax_stride_h, const size_t amax_stride_w,
const size_t h, const size_t w, const size_t start_offset, const size_t len) {
constexpr int kThreadsPerWarp = 32;
constexpr int kLoopsPerRow = kTileDim64 / kThreadsPerWarp;
constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp;
constexpr int kLoopsPerCol = kTileDim64 / kNumWarps;
const int tile_col = blockIdx.x;
const int tile_row = blockIdx.y;
const size_t end_offset = start_offset + len;
const IType *input_minus_offset = input - start_offset;
__shared__ float smem[kNumWarps];
float amax = 0.0f;
for (int loop_col = 0; loop_col < kLoopsPerCol; ++loop_col) {
size_t r = tile_row * kTileDim64 + loop_col * kNumWarps + threadIdx.x / kThreadsPerWarp;
for (int loop_row = 0; loop_row < kLoopsPerRow; ++loop_row) {
size_t c =
tile_col * kTileDim64 + loop_row * kThreadsPerWarp + (threadIdx.x % kThreadsPerWarp);
size_t idx = r * w + c;
if (r < h && c < w && idx >= start_offset && idx < end_offset) {
float other_amax = fabs(static_cast<float>(input_minus_offset[idx]));
__builtin_assume(amax >= 0);
__builtin_assume(other_amax >= 0);
amax = fmaxf(amax, other_amax);
}
}
}
for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) {
#ifdef __HIP_PLATFORM_AMD__
float other_amax = __shfl_down(amax, delta, kThreadsPerWarp);
#else
float other_amax = __shfl_down_sync(0xFFFFFFFF, amax, delta);
#endif
__builtin_assume(amax >= 0);
__builtin_assume(other_amax >= 0);
amax = fmaxf(amax, other_amax);
}
if (threadIdx.x % kThreadsPerWarp == 0) {
smem[threadIdx.x / kThreadsPerWarp] = amax;
}
__syncthreads();
if (threadIdx.x == 0) {
for (int i = 0; i < kNumWarps; ++i) {
float other_amax = smem[i];
__builtin_assume(amax >= 0);
__builtin_assume(other_amax >= 0);
amax = fmaxf(amax, other_amax);
}
amax_ptr[tile_row * amax_stride_h + tile_col * amax_stride_w] = amax;
}
}
template <typename IType, typename OType, bool kWidthAligned>
__global__ void __launch_bounds__(kThreadsPerBlock)
fp8_block_scaling_block_len64_partial_cast_kernel(const IType *input, OType *output,
const float *scale_ptr,
const size_t scale_stride_h,
const size_t scale_stride_w, const size_t h,
const size_t w, const size_t start_offset,
const size_t len) {
using transformer_engine::Vec;
static_assert(sizeof(OType) == 1);
constexpr int kNumOutputElemsPerBank = 4 / sizeof(OType);
constexpr int kThreadsPerWarp = 32;
constexpr int kLoopsPerRow = kTileDim64 / kThreadsPerWarp;
constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp;
constexpr int kRowsPerWarp = kTileDim64 / kNumWarps;
__shared__ OType smem[kTileDim64][kTileDim64 + kNumOutputElemsPerBank];
const int tile_w = blockIdx.x;
const int tile_h = blockIdx.y;
const size_t end_offset = start_offset + len;
const IType *input_minus_offset = input - start_offset;
OType *output_minus_offset = output - start_offset;
const float scale = scale_ptr[tile_h * scale_stride_h + tile_w * scale_stride_w];
// Load input data into shared memory
bool skip_store = true;
for (int i = 0; i < kRowsPerWarp; ++i) {
for (int j = 0; j < kLoopsPerRow; ++j) {
const int h_in_smem = threadIdx.x / kThreadsPerWarp * kRowsPerWarp + i;
const int w_in_smem = threadIdx.x % kThreadsPerWarp + kThreadsPerWarp * j;
const int h_in_input = tile_h * kTileDim64 + h_in_smem;
const int w_in_input = tile_w * kTileDim64 + w_in_smem;
const size_t idx_in_input = static_cast<size_t>(h_in_input) * w + w_in_input;
if (h_in_input < h && w_in_input < w && idx_in_input >= start_offset &&
idx_in_input < end_offset) {
float inp = static_cast<float>(input_minus_offset[idx_in_input]) * scale;
if constexpr (std::is_same_v<OType, int8_t>) {
smem[h_in_smem][w_in_smem] =
static_cast<OType>(lroundf(fmaxf(-127.0f, fminf(127.0f, inp))));
} else {
smem[h_in_smem][w_in_smem] = static_cast<OType>(inp);
}
skip_store = false;
}
}
}
for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) {
#ifdef __HIP_PLATFORM_AMD__
bool other_skip_store = __shfl_down(skip_store, delta, kThreadsPerWarp);
#else
bool other_skip_store = __shfl_down_sync(0xFFFFFFFF, skip_store, delta);
#endif
skip_store = skip_store && other_skip_store;
}
#ifdef __HIP_PLATFORM_AMD__
skip_store = __shfl(skip_store, 0, kThreadsPerWarp);
#else
skip_store = __shfl_sync(0xFFFFFFFF, skip_store, 0);
#endif
if (skip_store) {
return;
}
// Store the casted data into the output.
// Note that this store operation might write "out-of-bounds", but it is intentional:
// 1. The "out-of-bounds" here only crosses the boundary of the "local shard" (i.e., the region
// from start_offset to end_offset), not the boundary of the entire output memory. Therefore,
// this out-of-bounds write will not cause illegal memory access.
// 2. We assume that the subsequent all-gather operation happens in-place, so any parts that
// should not be updated here will be overwritten by the all-gather.
// This tricky approach allows us to avoid checking whether each output index falls within
// [start, end), resulting in a significant performance improvement.
Vec<OType, kNumOutputElemsPerBank> vec_output;
for (int i = 0; i < kRowsPerWarp; ++i) {
const int row_in_smem = threadIdx.x / kThreadsPerWarp * kRowsPerWarp + i;
const int col_in_smem = threadIdx.x % kThreadsPerWarp * kNumOutputElemsPerBank;
for (int j = 0; j < kNumOutputElemsPerBank; ++j) {
vec_output.data.elt[j] = smem[row_in_smem][col_in_smem + j];
}
const int row_in_output = tile_h * kTileDim64 + row_in_smem;
const int col_in_output = tile_w * kTileDim64 + col_in_smem;
const size_t idx_in_output = static_cast<size_t>(row_in_output) * w + col_in_output;
if (row_in_output < h) {
if constexpr (kWidthAligned) {
vec_output.store_to(output_minus_offset + idx_in_output);
} else {
int num = min(static_cast<size_t>(kNumOutputElemsPerBank),
static_cast<size_t>(col_in_output < w ? w - col_in_output : 0));
vec_output.store_to_elts(output_minus_offset, idx_in_output, num);
}
}
}
}
void fp8_block_scaling_compute_partial_amax(const Tensor inp, Tensor amax, size_t h, size_t w,
size_t amax_stride_h, size_t amax_stride_w,
size_t start_offset, size_t block_len,
cudaStream_t stream) {
NVTE_CHECK(block_len == 128, "Currently only block_len = 128 is supported");
NVTE_CHECK(block_len == 128 || block_len == 64,
"Currently only block_len = 128 or 64 is supported");
size_t len = inp.numel();
......@@ -187,26 +348,39 @@ void fp8_block_scaling_compute_partial_amax(const Tensor inp, Tensor amax, size_
assert(start_offset < h * w);
assert(start_offset + len <= h * w);
size_t blocks_x = (w + kTileDim - 1) / kTileDim;
size_t blocks_y = (h + kTileDim - 1) / kTileDim;
size_t blocks_x = (w + block_len - 1) / block_len;
size_t blocks_y = (h + block_len - 1) / block_len;
assert(blocks_x <= std::numeric_limits<unsigned int>::max());
assert(blocks_y <= std::numeric_limits<unsigned int>::max());
dim3 grid(blocks_x, blocks_y);
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
inp.dtype(), inp_dtype,
fp8_block_scaling_compute_partial_amax_kernel<inp_dtype>
<<<grid, kThreadsPerBlock, 0, stream>>>(reinterpret_cast<const inp_dtype *>(inp.data.dptr),
reinterpret_cast<float *>(amax.data.dptr),
amax_stride_h, amax_stride_w, h, w, start_offset,
len);)
inp.dtype(), inp_dtype, while (true) {
if (128 == block_len) {
fp8_block_scaling_compute_partial_amax_kernel<inp_dtype>
<<<grid, kThreadsPerBlock, 0, stream>>>(
reinterpret_cast<const inp_dtype *>(inp.data.dptr),
reinterpret_cast<float *>(amax.data.dptr), amax_stride_h, amax_stride_w, h, w,
start_offset, len);
break;
}
if (64 == block_len) {
fp8_block_scaling_block_len64_compute_partial_amax_kernel<inp_dtype>
<<<grid, kThreadsPerBlock, 0, stream>>>(
reinterpret_cast<const inp_dtype *>(inp.data.dptr),
reinterpret_cast<float *>(amax.data.dptr), amax_stride_h, amax_stride_w, h, w,
start_offset, len);
break;
}
})
}
void fp8_block_scaling_partial_cast(const Tensor inp, Tensor out, const Tensor scale, size_t h,
size_t w, size_t scale_stride_h, size_t scale_stride_w,
size_t start_offset, size_t block_len, const DType out_dtype,
cudaStream_t stream) {
NVTE_CHECK(block_len == 128, "Currently only block_len = 128 is supported");
NVTE_CHECK(block_len == 128 || block_len == 64,
"Currently only block_len = 128 or 64 is supported");
size_t len = inp.numel();
......@@ -214,8 +388,8 @@ void fp8_block_scaling_partial_cast(const Tensor inp, Tensor out, const Tensor s
assert(start_offset < h * w);
assert(start_offset + len <= h * w);
size_t blocks_x = (w + kTileDim - 1) / kTileDim;
size_t blocks_y = (h + kTileDim - 1) / kTileDim;
size_t blocks_x = (w + block_len - 1) / block_len;
size_t blocks_y = (h + block_len - 1) / block_len;
assert(blocks_x <= std::numeric_limits<unsigned int>::max());
assert(blocks_y <= std::numeric_limits<unsigned int>::max());
dim3 grid(blocks_x, blocks_y);
......@@ -225,13 +399,27 @@ void fp8_block_scaling_partial_cast(const Tensor inp, Tensor out, const Tensor s
TRANSFORMER_ENGINE_TYPE_SWITCH_8BIT(
out_dtype, fp8_type,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
w % kTileDim == 0, kWidthAligned,
fp8_block_scaling_partial_cast_kernel<inp_dtype, fp8_type, kWidthAligned>
<<<grid, kThreadsPerBlock, 0, stream>>>(
reinterpret_cast<const inp_dtype *>(inp.data.dptr),
reinterpret_cast<fp8_type *>(out.data.dptr),
reinterpret_cast<const float *>(scale.data.dptr), scale_stride_h, scale_stride_w,
h, w, start_offset, len);)))
w % block_len == 0, kWidthAligned, while (true) {
if (128 == block_len) {
fp8_block_scaling_partial_cast_kernel<inp_dtype, fp8_type, kWidthAligned>
<<<grid, kThreadsPerBlock, 0, stream>>>(
reinterpret_cast<const inp_dtype *>(inp.data.dptr),
reinterpret_cast<fp8_type *>(out.data.dptr),
reinterpret_cast<const float *>(scale.data.dptr), scale_stride_h,
scale_stride_w, h, w, start_offset, len);
break;
}
if (64 == block_len) {
fp8_block_scaling_block_len64_partial_cast_kernel<inp_dtype, fp8_type,
kWidthAligned>
<<<grid, kThreadsPerBlock, 0, stream>>>(
reinterpret_cast<const inp_dtype *>(inp.data.dptr),
reinterpret_cast<fp8_type *>(out.data.dptr),
reinterpret_cast<const float *>(scale.data.dptr), scale_stride_h,
scale_stride_w, h, w, start_offset, len);
break;
}
})))
}
} // namespace fp8_block_scaling_recipe
......
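For readers cross-checking the new 64-wide kernels, a PyTorch reference sketch of what they compute per tile; this is an assumption for illustration, and the shard-offset handling and the in-place all-gather trick documented above are omitted:

import torch

def ref_tile_amax(x: torch.Tensor, tile: int = 64) -> torch.Tensor:
    # Absolute max over each tile x tile block, zero-padded at the edges.
    h, w = x.shape
    gh, gw = (h + tile - 1) // tile, (w + tile - 1) // tile
    padded = torch.zeros(gh * tile, gw * tile, dtype=x.dtype)
    padded[:h, :w] = x
    return padded.view(gh, tile, gw, tile).abs().amax(dim=(1, 3))

def ref_tile_cast(x, scale, tile=64, to_int8=False):
    # Scale each element by its tile's entry in the (gh, gw) scale grid.
    h, w = x.shape
    s = scale.repeat_interleave(tile, 0).repeat_interleave(tile, 1)[:h, :w]
    y = x.float() * s
    if to_int8:
        # Clamp before rounding, matching lroundf(fmaxf(-127, fminf(127, .)))
        return y.clamp(-127, 127).round().to(torch.int8)
    return y.to(torch.float8_e4m3fn)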
......@@ -15,6 +15,7 @@ from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt_wgrad import w8a8_
from ..tensor.quantized_tensor import Quantizer
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ...debug.pytorch.debug_quantization import DebugQuantizer
from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))
__all__ = [
......@@ -76,7 +77,7 @@ def general_gemm(
ref_scales_w = A._rowwise_scale_inv
y, _ = w8a8_block_int8_matmul(
qx_data, qw_data, ref_scales_x, ref_scales_w, [128, 128],
qx_data, qw_data, ref_scales_x, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len],
output_dtype=out_dtype
)
return y, None, None, None
......@@ -92,7 +93,7 @@ def general_gemm(
ref_scales_w = A._columnwise_scale_inv
y, _ = w8a8_block_int8_matmul(
qdout_data, qw_data, ref_scales_dout, ref_scales_w, [128, 128],
qdout_data, qw_data, ref_scales_dout, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len],
output_dtype=out_dtype
)
return y, None, None, None
......@@ -108,7 +109,7 @@ def general_gemm(
ref_scales_x = A._columnwise_scale_inv
out, _ = w8a8_block_int8_matmul_wgrad(
qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [128, 128],
qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [blockwise_fp8_block_len, blockwise_fp8_block_len],
output_dtype=out_dtype
)
return out, None, None, None
......@@ -243,7 +244,7 @@ def general_grouped_gemm(
seq_len = sum(m_splits) // num_gemms
out[0] = w8a8_block_int8_matmul_batched(
qx_data, qw_data, ref_scales_x, ref_scales_w, out[0].view(num_gemms, seq_len, out[0].size(-1)), [128, 128],
qx_data, qw_data, ref_scales_x, ref_scales_w, out[0].view(num_gemms, seq_len, out[0].size(-1)), [blockwise_fp8_block_len, blockwise_fp8_block_len],
output_dtype=out_dtype
)
return out, bias, gelu_input
......@@ -262,7 +263,7 @@ def general_grouped_gemm(
seq_len = sum(m_splits) // num_gemms
out[0] = w8a8_block_int8_matmul_batched(
qdout_data, qw_data, ref_scales_dout, ref_scales_w, out[0].view(num_gemms, seq_len, out[0].size(-1)), [128, 128],
qdout_data, qw_data, ref_scales_dout, ref_scales_w, out[0].view(num_gemms, seq_len, out[0].size(-1)), [blockwise_fp8_block_len, blockwise_fp8_block_len],
output_dtype=out_dtype
)
return out, bias, gelu_input
......@@ -278,7 +279,7 @@ def general_grouped_gemm(
ref_scales_x = [a._columnwise_scale_inv for a in A]
out = w8a8_block_int8_matmul_wgrad_batched_native(
qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [128, 128],
qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [blockwise_fp8_block_len, blockwise_fp8_block_len],
output_dtype=out_dtype
)
return out, bias, gelu_input
......
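A dequantize-then-matmul reference for the blockwise int8 calls above, sketched under assumptions (2D [block_len, block_len] scales for both operands and the NT layout used in the tests); it is not the Triton kernel's actual strategy:

import torch

def ref_w8a8_block_matmul(qx, qw, sx, sw, block_len, out_dtype):
    def dequant(q, s):
        h, w = q.shape
        # Broadcast each per-tile scale over its block_len x block_len tile;
        # scale grids may be padded, so slice back to the data shape.
        full = s.float().repeat_interleave(block_len, 0).repeat_interleave(block_len, 1)
        return q.float() * full[:h, :w]
    # qx: (M, K) activations, qw: (N, K) weights -> (M, N) output.
    return (dequant(qx, sx) @ dequant(qw, sw).t()).to(out_dtype)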
......@@ -48,6 +48,8 @@
#include <cassert>
#include <cstring>
#include <iostream>
#include <string>
#include <sstream>
#include <memory>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <vector>
......@@ -60,6 +62,18 @@ namespace transformer_engine::pytorch {
// in python we have: dist_group_type = torch.distributed.ProcessGroup
using dist_group_type = c10d::ProcessGroup;
inline int blockwise_fp8_block_len() {
const char *env = std::getenv("NVTE_BLOCKWISE_FP8_BLOCK_LEN");
if (env == nullptr || env[0] == '\0') {
return 128;
}
int value;
std::istringstream iss(env);
iss >> value;
NVTE_CHECK(iss, "Invalid environment variable value");
return value;
}
// Each tensor here is shape (N, ) holding all scaling
// data for a single FP8 block, e.g. LayerNormLinear
class FP8TensorMeta {
......
......@@ -10,7 +10,7 @@ namespace transformer_engine::pytorch {
void fp8_block_scaling_compute_partial_amax(const at::Tensor &tensor, at::Tensor amax, size_t h,
size_t w, size_t start_offset, size_t block_len) {
TORCH_CHECK(block_len == 128, "Currently only block_len = 128 is supported");
TORCH_CHECK(block_len == 128 || block_len == 64, "Currently only block_len = 128 or 64 is supported");
TORCH_CHECK(amax.dim() == 2, "amax must be a 2D tensor");
TORCH_CHECK(amax.scalar_type() == at::ScalarType::Float, "amax must be a float tensor");
TORCH_CHECK(tensor.scalar_type() == at::ScalarType::Float ||
......@@ -28,7 +28,7 @@ void fp8_block_scaling_compute_partial_amax(const at::Tensor &tensor, at::Tensor
void fp8_block_scaling_partial_cast(const at::Tensor &inp, at::Tensor out, const at::Tensor &scale,
size_t h, size_t w, size_t start_offset, size_t block_len,
const transformer_engine::DType out_dtype) {
TORCH_CHECK(block_len == 128, "Currently only block_len = 128 is supported");
TORCH_CHECK(block_len == 128 || block_len == 64, "Currently only block_len = 128 or 64 is supported");
TORCH_CHECK(scale.dim() == 2, "scale must be a 2D tensor");
TORCH_CHECK(scale.scalar_type() == at::ScalarType::Float, "scale must be a float tensor");
TORCH_CHECK(
......
......@@ -297,7 +297,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
size_t k_dim = torch_shape.size() == 0 ? 1u : torch_shape.back();
size_t m_dim = numel / k_dim;
constexpr size_t kBlockLen = 128;
size_t kBlockLen = static_cast<size_t>(blockwise_fp8_block_len());
if (rowwise_usage) {
if (rowwise_data.has_value()) {
......
......@@ -1018,7 +1018,7 @@ def _all_gather_fp8_blockwise(
# Check that quantizer is valid
if quantizer is not None and not isinstance(quantizer, Float8BlockQuantizer):
raise ValueError(f"Got non-FP8 blockwise quantizer ({quantizer.__class__.__name__})")
if not (quantizer.block_scaling_dim == 1 and quantizer.block_len == 128):
if not (quantizer.block_scaling_dim == 1 and (quantizer.block_len == 128 or quantizer.block_len == 64)):
raise NotImplementedError("Only 1D blockwise quantization is supported for allgather")
# Output tensor dims
......
......@@ -28,6 +28,7 @@ from .utils import get_device_compute_capability
from .jit import jit_fuser
from torch.utils.cpp_extension import IS_HIP_EXTENSION
int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))
blockwise_fp8_block_len = int(os.getenv("NVTE_BLOCKWISE_FP8_BLOCK_LEN", "128"))
__all__ = ["fp8_autocast", "fp8_model_init"]
......
......@@ -11,6 +11,7 @@ import torch
import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
from ..quantized_tensor import QuantizedTensorBase
......@@ -125,7 +126,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
return torch.permute(columnwise_dq, tuple(permute_dims)).contiguous()
def _dequantize_vectorwise(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
block_len = 128
block_len = blockwise_fp8_block_len
q_M, q_K = 1, 1
if self._rowwise_data is not None:
......@@ -178,7 +179,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
"""
Construct plain PyTorch tensor from Float8BlockwiseQTensor
"""
block_len = 128
block_len = blockwise_fp8_block_len
if not self._is_2D_scaled:
return self._dequantize_vectorwise(dtype=dtype)
......
......@@ -14,6 +14,7 @@ from transformer_engine_torch import DType as TE_DType
from ._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
from ..utils import devices_match, round_up_to_nearest_multiple
from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
aten = torch.ops.aten
......@@ -46,7 +47,7 @@ class Float8BlockQuantizer(Quantizer):
) -> None:
super().__init__(rowwise=rowwise, columnwise=columnwise)
self.dtype = tex.DType.kInt8 if int8_simulation_fp8 else fp8_dtype
self.block_len = 128
self.block_len = blockwise_fp8_block_len
self.force_pow_2_scales = force_pow_2_scales
self.amax_epsilon = amax_epsilon
self.block_scaling_dim = block_scaling_dim
......
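End to end, a quantizer constructed after exporting the env picks up the new default; the constructor arguments below are an assumption based on the fields shown above:

import os
os.environ["NVTE_BLOCKWISE_FP8_BLOCK_LEN"] = "64"  # before importing TE

import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer

quantizer = Float8BlockQuantizer(
    fp8_dtype=tex.DType.kFloat8E4M3,
    rowwise=True,
    columnwise=True,
    block_scaling_dim=2,
)
assert quantizer.block_len == 64  # env-selected default from fp8.py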