Commit b944277c authored by wenjh

[Blockwise] Add block_len=64 support



Add an env var to choose the block length of blockwise quantize.
Signed-off-by: wenjh <wenjh@sugon.com>

Fix blockwise pytest error
Signed-off-by: wenjh <wenjh@sugon.com>

Resolve new API in int8 GEMM test
Signed-off-by: wenjh <wenjh@sugon.com>

Fix incorrect launch param
Signed-off-by: wenjh <wenjh@sugon.com>

Fix 1D blockwise (64) acc error
Signed-off-by: wenjh <wenjh@sugon.com>
parent 251dcc7e
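For context, a minimal sketch of how the new switch is meant to be driven. The variable name NVTE_BLOCKWISE_FP8_BLOCK_LEN and the Python-side blockwise_fp8_block_len value both appear in the diff below; the call sequence itself is an assumption (the variable must be set before the module reads it):

    import os

    # Assumed usage: request 64-element blocks; unset or empty falls back to 128.
    os.environ["NVTE_BLOCKWISE_FP8_BLOCK_LEN"] = "64"

    from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
    print(blockwise_fp8_block_len)  # expected: 64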
...@@ -25,7 +25,11 @@ struct QuantizationOptions {
size_t block_scaling_dim = 2u;
};
#ifdef __HIP_PLATFORM_AMD__
size_t kBlockLen = static_cast<size_t>(blockwise_fp8_block_len());
#else
constexpr size_t kBlockLen = 128;
#endif
enum ProcessingMethod {
CAST_ONLY,
...@@ -80,8 +84,13 @@ template <typename InputType, typename OutputType>
void ref_quantize(const ProcessingMethod processing_method, const InputType* input,
const std::pair<size_t, size_t>& input_hw, OutputType* output, float* scale_inv,
OutputType* output_t, float* scale_inv_t, const QuantizationOptions& opts) {
#ifdef __HIP_PLATFORM_AMD__
size_t kBlockLenX = kBlockLen;
size_t kBlockLenY = kBlockLen;
#else
constexpr size_t kBlockLenX = kBlockLen;
constexpr size_t kBlockLenY = kBlockLen;
#endif
auto quantize_element = [](InputType element, float qscale) -> OutputType {
// Scale in FP32 and cast result to nearest FP8.
...@@ -157,7 +166,11 @@ void ref_quantize_onedimensional_blocks(const ProcessingMethod processing_method
float input_type_max_val = Quantized_Limits<InputType>::max();
float quant_type_max_val = Quantized_Limits<OutputType>::max();
#ifdef __HIP_PLATFORM_AMD__
size_t kBlockLenX = kBlockLen;
#else
constexpr size_t kBlockLenX = kBlockLen;
#endif
auto quantize_element = [](InputType element, float qscale) -> OutputType {
// Scale in FP32 and cast result to nearest FP8.
...
...@@ -168,13 +168,13 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
scale_inv_meta ret_rowwise, ret_colwise;
{
auto scale_dim_0 = DIVUP(first_dim, static_cast<size_t>(blockwise_fp8_block_len()));
auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast<size_t>(blockwise_fp8_block_len())), 4) * 4;
ret_rowwise.shape = {scale_dim_0, scale_dim_1};
}
{
auto scale_dim_0 = DIVUP(last_dim, static_cast<size_t>(blockwise_fp8_block_len()));
auto scale_dim_1 = DIVUP(DIVUP(first_dim, static_cast<size_t>(blockwise_fp8_block_len())), 4) * 4;
ret_colwise.shape = {scale_dim_0, scale_dim_1};
}
ret_rowwise.type = DType::kFloat32;
...@@ -194,12 +194,12 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
scale_inv_meta ret_rowwise, ret_colwise;
{
auto scale_dim_0 = DIVUP(last_dim, static_cast<size_t>(blockwise_fp8_block_len()));
auto scale_dim_1 = DIVUP(first_dim, 4) * 4;
ret_rowwise.shape = {scale_dim_0, scale_dim_1};
}
{
auto scale_dim_0 = DIVUP(first_dim, static_cast<size_t>(blockwise_fp8_block_len()));
auto scale_dim_1 = DIVUP(last_dim, 4) * 4;
ret_colwise.shape = {scale_dim_0, scale_dim_1};
}
...
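To make the scale-shape arithmetic above concrete, a small worked example of the rowwise branch (DIVUP is ceil-division; the second dimension is padded up to a multiple of 4, matching the code above; the concrete sizes are assumptions):

    def divup(a: int, b: int) -> int:
        return (a + b - 1) // b

    block_len = 64                 # from NVTE_BLOCKWISE_FP8_BLOCK_LEN
    first_dim, last_dim = 300, 500
    scale_dim_0 = divup(first_dim, block_len)               # 5
    scale_dim_1 = divup(divup(last_dim, block_len), 4) * 4  # divup(8, 4) * 4 = 8
    print((scale_dim_0, scale_dim_1))                       # rowwise scale_inv shape (5, 8)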
...@@ -22,6 +22,18 @@
namespace test {
using namespace transformer_engine;
inline int blockwise_fp8_block_len() {
const char *env = std::getenv("NVTE_BLOCKWISE_FP8_BLOCK_LEN");
if (env == nullptr || env[0] == '\0') {
return 128;
}
int value;
std::istringstream iss(env);
iss >> value;
NVTE_CHECK(iss, "Invalid environment variable value");
return value;
}
template <size_t i>
struct BytesToType {};
...
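The C++ test helper above reduces to: take the env var if present and well formed, otherwise default to 128. A Python mirror of the same semantics, for illustration only:

    import os

    def blockwise_fp8_block_len() -> int:
        # Default to 128 when the variable is unset or empty.
        env = os.environ.get("NVTE_BLOCKWISE_FP8_BLOCK_LEN", "")
        if not env:
            return 128
        return int(env)  # a malformed value raises, mirroring NVTE_CHECK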
...@@ -8,6 +8,7 @@ import torch
import triton
import triton.language as tl
from torch.utils.cpp_extension import IS_HIP_EXTENSION
from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
@triton.jit
...@@ -135,7 +136,7 @@ class CuBLASRefBlockwiseGemm:
N, K_w = qw.shape
assert K == K_w, "K dimension mismatch between qx and qw"
tile_len = blockwise_fp8_block_len
# Calculate grid sizes without padding
grid_m = (M + tile_len - 1) // tile_len
grid_n = (N + tile_len - 1) // tile_len
...
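The grid computation above is plain ceil-division over the tile length; a tiny worked example (sizes assumed):

    tile_len = 64
    M, N = 300, 192
    grid_m = (M + tile_len - 1) // tile_len  # 5
    grid_n = (N + tile_len - 1) // tile_len  # 3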
...@@ -7,7 +7,7 @@ import math
import torch
from typing import Optional, Protocol, Tuple
from references.quantize_scale_calc import scale_from_amax_tensor
from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
@dataclasses.dataclass()
class QuantizeResult:
...@@ -277,7 +277,7 @@ class BlockwiseQuantizerReference:
return_transpose: bool = False,
eps: float = 0.0,
pow_2_scales: bool = False,
quant_tile_shape: Tuple[int, int] = (blockwise_fp8_block_len, blockwise_fp8_block_len),
) -> QuantizeResult:
# sanity checks
assert x.dim() == 2
...@@ -293,7 +293,7 @@ class BlockwiseQuantizerReference:
torch.int8,
), "Unsupported quant dtype."
assert quant_tile_shape in ((1, blockwise_fp8_block_len), (blockwise_fp8_block_len, blockwise_fp8_block_len))
if quant_tile_shape[0] == 1:
# Quantize row-wise
return self.scale_munger.munge_scale_shapes_for_backend(
...
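The assertion above encodes the scaling granularity: (1, block_len) selects 1D row-wise blocks and (block_len, block_len) selects 2D square blocks, which is how block_scaling_dim is chosen throughout the tests. A sketch of that mapping (the helper name is hypothetical):

    from typing import Tuple

    def block_scaling_dim_from_tile(tile: Tuple[int, int], block_len: int) -> int:
        if tile == (1, block_len):
            return 1  # 1D row-wise scaling
        if tile == (block_len, block_len):
            return 2  # 2D square-block scaling
        raise ValueError("Unsupported tile size")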
...@@ -8,7 +8,7 @@ import transformer_engine as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.fp8 import (FP8GlobalStateManager, blockwise_fp8_block_len)
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
Float8BlockQuantizer,
Float8BlockwiseQTensor,
...@@ -77,8 +77,9 @@ def cublas_gemm_fp8_blockwise_case(
assert not (use_bias and use_grad), "Bias grad not supported by GEMM"
# Set quantize_op and quantization parameters
block_len = blockwise_fp8_block_len
x_quant_tile_shape = (1, block_len) if is_x_1d_scaled else (block_len, block_len)
w_quant_tile_shape = (1, block_len) if is_w_1d_scaled else (block_len, block_len)
x_block_scaling_dim = 1 if is_x_1d_scaled else 2
w_block_scaling_dim = 1 if is_w_1d_scaled else 2
x_te_dtype = TE_DType[x_dtype]
...@@ -247,8 +248,9 @@ def cublas_gemm_test_constraint_enforced(
out = None
# Set quantize_op and quantization parameters
block_len = blockwise_fp8_block_len
x_quant_tile_shape = (1, block_len) if is_x_1d_scaled else (block_len, block_len)
w_quant_tile_shape = (1, block_len) if is_w_1d_scaled else (block_len, block_len)
x_block_scaling_dim = 1 if is_x_1d_scaled else 2
w_block_scaling_dim = 1 if is_w_1d_scaled else 2
x_te_dtype = TE_DType[x_dtype]
...
...@@ -10,7 +10,7 @@ import pytest
import torch
import transformer_engine as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import (FP8GlobalStateManager, blockwise_fp8_block_len)
from transformer_engine.common.recipe import Float8BlockScaling
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
...@@ -99,9 +99,9 @@ def check_quantization_block_tiling_versus_reference(
tile_size: Tuple[int, int],
) -> None:
te_dtype = TE_DType[quant_dtype]
if tile_size in ((1, 128), (1, 64)):
block_scaling_dim = 1
elif tile_size in ((128, 128), (64, 64)):
block_scaling_dim = 2
else:
raise ValueError("Unsupported tile size")
...@@ -214,7 +214,7 @@ def check_quantization_block_tiling_versus_reference(
"return_transpose", [True, False], ids=["quantize_transpose", "quantize_only"]
)
@pytest.mark.parametrize("pow_2_scales", [True], ids=["pow2scales"])
@pytest.mark.parametrize("tile_size", [(1, 128), (128, 128), (1, 64), (64, 64)], ids=["1D128Tile", "2D128Tile", "1D64Tile", "2D64Tile"])
def test_quantization_block_tiling_versus_reference(
x_dtype: torch.dtype,
M: int,
...@@ -225,6 +225,8 @@ def test_quantization_block_tiling_versus_reference(
pow_2_scales: bool,
tile_size: Tuple[int, int],
) -> None:
if blockwise_fp8_block_len != tile_size[1]:
pytest.skip("Tile size does not match the block length selected by env.")
check_quantization_block_tiling_versus_reference(
x_dtype, M, N, quant_dtype, eps, return_transpose, pow_2_scales, tile_size
)
...@@ -249,7 +251,7 @@ def test_quantization_block_tiling_versus_reference(
"return_transpose", [True, False], ids=["quantize_transpose", "quantize_only"]
)
@pytest.mark.parametrize("pow_2_scales", [False], ids=["fp32scales"])
@pytest.mark.parametrize("tile_size", [(1, 128), (128, 128), (1, 64), (64, 64)], ids=["1D128Tile", "2D128Tile", "1D64Tile", "2D64Tile"])
def test_quantization_block_tiling_versus_reference_fp32_scales(
x_dtype: torch.dtype,
M: int,
...@@ -260,6 +262,8 @@ def test_quantization_block_tiling_versus_reference_fp32_scales(
pow_2_scales: bool,
tile_size: Tuple[int, int],
) -> None:
if blockwise_fp8_block_len != tile_size[1]:
pytest.skip("Tile size does not match the block length selected by env.")
check_quantization_block_tiling_versus_reference(
x_dtype, M, N, quant_dtype, eps, return_transpose, pow_2_scales, tile_size
)
...@@ -277,7 +281,7 @@ def test_quantization_block_tiling_versus_reference_fp32_scales(
@pytest.mark.parametrize("quant_dtype", [torch.int8, torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("eps", [0], ids=["eps_0"])
@pytest.mark.parametrize("pow_2_scales", [True, False], ids=["pow2scales", "fp32scales"])
@pytest.mark.parametrize("tile_size", [(128, 128), (64, 64)], ids=["2D128Tile", "2D64Tile"])
@pytest.mark.parametrize("extrema_high", [False, True], ids=["zeros", "maxes"])
def test_quantization_block_tiling_extrema_versus_reference(
x_dtype: torch.dtype,
...@@ -291,10 +295,12 @@ def test_quantization_block_tiling_extrema_versus_reference(
) -> None:
# This test runs a single tile through a quantizer as a way to test
# branch coverage of scale computation.
if blockwise_fp8_block_len != tile_size[1]:
pytest.skip("Tile size does not match the block length selected by env.")
te_dtype = TE_DType[quant_dtype]
if tile_size in ((1, 128), (1, 64)):
block_scaling_dim = 1
elif tile_size in ((128, 128), (64, 64)):
block_scaling_dim = 2
else:
raise ValueError("Unsupported tile size")
...
...@@ -4,7 +4,7 @@ import transformer_engine as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.fp8 import (FP8GlobalStateManager, blockwise_fp8_block_len)
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
Float8BlockQuantizer,
Float8BlockwiseQTensor,
...@@ -82,8 +82,9 @@ def cublas_gemm_fp8_blockwise_case_fw(
assert not (use_bias and use_grad), "Bias grad not supported by GEMM"
# Set quantize_op and quantization parameters
block_len = blockwise_fp8_block_len
x_quant_tile_shape = (1, block_len) if is_x_1d_scaled else (block_len, block_len)
w_quant_tile_shape = (1, block_len) if is_w_1d_scaled else (block_len, block_len)
x_block_scaling_dim = 1 if is_x_1d_scaled else 2
w_block_scaling_dim = 1 if is_w_1d_scaled else 2
x_te_dtype = TE_DType[x_dtype]
...@@ -196,7 +197,7 @@ def cublas_gemm_fp8_blockwise_case_fw(
ref_scales_w = qw._columnwise_scale_inv if w_columnwise else qw._rowwise_scale_inv
y, _ = w8a8_block_int8_matmul(
qx_data, qw_data, ref_scales_x, ref_scales_w, [block_len, block_len],
output_dtype=out_dtype
)
...@@ -265,8 +266,9 @@ def cublas_gemm_fp8_blockwise_case_bw_xgrad(
assert not (use_bias and use_grad), "Bias grad not supported by GEMM"
# Set quantize_op and quantization parameters
block_len = blockwise_fp8_block_len
dout_quant_tile_shape = (1, block_len) if is_dout_1d_scaled else (block_len, block_len)
w_quant_tile_shape = (1, block_len) if is_w_1d_scaled else (block_len, block_len)
dout_block_scaling_dim = 1 if is_dout_1d_scaled else 2
w_block_scaling_dim = 1 if is_w_1d_scaled else 2
dout_te_dtype = TE_DType[dout_dtype]
...@@ -373,7 +375,7 @@ def cublas_gemm_fp8_blockwise_case_bw_xgrad(
ref_scales_w = qw._columnwise_scale_inv if w_columnwise else qw._rowwise_scale_inv
y, _ = w8a8_block_int8_matmul(
qdout_data, qw_data, ref_scales_dout, ref_scales_w, [block_len, block_len],
output_dtype=dx_dtype
)
...@@ -441,8 +443,9 @@ def cublas_gemm_fp8_blockwise_case_bw_wgrad(
assert not (use_bias and use_grad), "Bias grad not supported by GEMM"
# Set quantize_op and quantization parameters
block_len = blockwise_fp8_block_len
dout_quant_tile_shape = (1, block_len) if is_dout_1d_scaled else (block_len, block_len)
x_quant_tile_shape = (1, block_len) if is_x_1d_scaled else (block_len, block_len)
dout_block_scaling_dim = 1 if is_dout_1d_scaled else 2
x_block_scaling_dim = 1 if is_x_1d_scaled else 2
dout_te_dtype = TE_DType[dout_dtype]
...@@ -552,7 +555,8 @@ def cublas_gemm_fp8_blockwise_case_bw_wgrad(
# print(f"ref_scales_dout.shape: {ref_scales_dout.shape}, ref_scales_x.shape: {ref_scales_x.shape}")
y, _ = w8a8_block_int8_matmul_wgrad(
qdout_data, qx_data, ref_scales_dout, ref_scales_x, dw.clone() if accumulate else None,
accumulate, [block_len, block_len],
output_dtype=dw_dtype
)
...
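The wgrad call above now threads through an accumulation buffer and flag: when accumulate is set, the previous dw is cloned and the new block-scaled product is added onto it. A hedged dense-reference sketch of the intended semantics, assuming the usual dW = dY^T @ X orientation (the real w8a8_block_int8_matmul_wgrad dequantizes per block tile before accumulating):

    import torch

    def wgrad_reference(dout, x, dw_prev, accumulate):
        # dout: (M, N), x: (M, K) -> dw: (N, K)
        dw = dout.to(torch.float32).t() @ x.to(torch.float32)
        if accumulate:
            dw = dw + dw_prev.to(torch.float32)
        return dw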
...@@ -6,6 +6,7 @@
#ifndef TRANSFORMER_ENGINE_COMMON_COMMON_H_
#define TRANSFORMER_ENGINE_COMMON_COMMON_H_
#include "util/system.h"
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h>
#endif
...@@ -33,6 +34,10 @@ namespace transformer_engine {
std::string to_string(const DType type);
std::string to_string(const NVTEScalingMode &mode);
inline int blockwise_fp8_block_len() {
return ::transformer_engine::getenv<int>("NVTE_BLOCKWISE_FP8_BLOCK_LEN", 128);
}
inline bool is_tensor_scaling(const NVTEScalingMode &mode) {
return mode == NVTE_DELAYED_TENSOR_SCALING;
}
...
...@@ -14,6 +14,7 @@
namespace transformer_engine {
namespace fp8_block_scaling_recipe {
constexpr int kTileDim64 = 64;
constexpr int kTileDim = 128;
constexpr int kThreadsPerBlock = 256;
...@@ -116,10 +117,10 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
if (h_in_input < h && w_in_input < w && idx_in_input >= start_offset &&
idx_in_input < end_offset) {
float inp = static_cast<float>(input_minus_offset[idx_in_input]) * scale;
if constexpr (std::is_same_v<OType, int8_t>) {
smem[h_in_smem][w_in_smem] =
static_cast<OType>(lroundf(fmaxf(-127.0f, fminf(127.0f, inp))));
} else {
smem[h_in_smem][w_in_smem] = static_cast<OType>(inp);
}
skip_store = false;
...@@ -175,11 +176,171 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
}
}
template <typename IType>
__global__ void __launch_bounds__(kThreadsPerBlock)
fp8_block_scaling_block_len64_compute_partial_amax_kernel(
const IType *input, float *amax_ptr, const size_t amax_stride_h, const size_t amax_stride_w,
const size_t h, const size_t w, const size_t start_offset, const size_t len) {
constexpr int kThreadsPerWarp = 32;
constexpr int kLoopsPerRow = kTileDim64 / kThreadsPerWarp;
constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp;
constexpr int kLoopsPerCol = kTileDim64 / kNumWarps;
const int tile_col = blockIdx.x;
const int tile_row = blockIdx.y;
const size_t end_offset = start_offset + len;
const IType *input_minus_offset = input - start_offset;
__shared__ float smem[kNumWarps];
float amax = 0.0f;
for (int loop_col = 0; loop_col < kLoopsPerCol; ++loop_col) {
size_t r = tile_row * kTileDim64 + loop_col * kNumWarps + threadIdx.x / kThreadsPerWarp;
for (int loop_row = 0; loop_row < kLoopsPerRow; ++loop_row) {
size_t c =
tile_col * kTileDim64 + loop_row * kThreadsPerWarp + (threadIdx.x % kThreadsPerWarp);
size_t idx = r * w + c;
if (r < h && c < w && idx >= start_offset && idx < end_offset) {
float other_amax = fabs(static_cast<float>(input_minus_offset[idx]));
__builtin_assume(amax >= 0);
__builtin_assume(other_amax >= 0);
amax = fmaxf(amax, other_amax);
}
}
}
for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) {
#ifdef __HIP_PLATFORM_AMD__
float other_amax = __shfl_down(amax, delta, kThreadsPerWarp);
#else
float other_amax = __shfl_down_sync(0xFFFFFFFF, amax, delta);
#endif
__builtin_assume(amax >= 0);
__builtin_assume(other_amax >= 0);
amax = fmaxf(amax, other_amax);
}
if (threadIdx.x % kThreadsPerWarp == 0) {
smem[threadIdx.x / kThreadsPerWarp] = amax;
}
__syncthreads();
if (threadIdx.x == 0) {
for (int i = 0; i < kNumWarps; ++i) {
float other_amax = smem[i];
__builtin_assume(amax >= 0);
__builtin_assume(other_amax >= 0);
amax = fmaxf(amax, other_amax);
}
amax_ptr[tile_row * amax_stride_h + tile_col * amax_stride_w] = amax;
}
}
template <typename IType, typename OType, bool kWidthAligned>
__global__ void __launch_bounds__(kThreadsPerBlock)
fp8_block_scaling_block_len64_partial_cast_kernel(const IType *input, OType *output,
const float *scale_ptr,
const size_t scale_stride_h,
const size_t scale_stride_w, const size_t h,
const size_t w, const size_t start_offset,
const size_t len) {
using transformer_engine::Vec;
static_assert(sizeof(OType) == 1);
constexpr int kNumOutputElemsPerBank = 4 / sizeof(OType);
constexpr int kThreadsPerWarp = 32;
constexpr int kLoopsPerRow = kTileDim64 / kThreadsPerWarp;
constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp;
constexpr int kRowsPerWarp = kTileDim64 / kNumWarps;
__shared__ OType smem[kTileDim64][kTileDim64 + kNumOutputElemsPerBank];
const int tile_w = blockIdx.x;
const int tile_h = blockIdx.y;
const size_t end_offset = start_offset + len;
const IType *input_minus_offset = input - start_offset;
OType *output_minus_offset = output - start_offset;
const float scale = scale_ptr[tile_h * scale_stride_h + tile_w * scale_stride_w];
// Load input data into shared memory
bool skip_store = true;
for (int i = 0; i < kRowsPerWarp; ++i) {
for (int j = 0; j < kLoopsPerRow; ++j) {
const int h_in_smem = threadIdx.x / kThreadsPerWarp * kRowsPerWarp + i;
const int w_in_smem = threadIdx.x % kThreadsPerWarp + kThreadsPerWarp * j;
const int h_in_input = tile_h * kTileDim64 + h_in_smem;
const int w_in_input = tile_w * kTileDim64 + w_in_smem;
const size_t idx_in_input = static_cast<size_t>(h_in_input) * w + w_in_input;
if (h_in_input < h && w_in_input < w && idx_in_input >= start_offset &&
idx_in_input < end_offset) {
float inp = static_cast<float>(input_minus_offset[idx_in_input]) * scale;
if constexpr (std::is_same_v<OType, int8_t>) {
smem[h_in_smem][w_in_smem] =
static_cast<OType>(lroundf(fmaxf(-127.0f, fminf(127.0f, inp))));
} else {
smem[h_in_smem][w_in_smem] = static_cast<OType>(inp);
}
skip_store = false;
}
}
}
for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) {
#ifdef __HIP_PLATFORM_AMD__
bool other_skip_store = __shfl_down(skip_store, delta, kThreadsPerWarp);
#else
bool other_skip_store = __shfl_down_sync(0xFFFFFFFF, skip_store, delta);
#endif
skip_store = skip_store && other_skip_store;
}
#ifdef __HIP_PLATFORM_AMD__
skip_store = __shfl(skip_store, 0, kThreadsPerWarp);
#else
skip_store = __shfl_sync(0xFFFFFFFF, skip_store, 0);
#endif
if (skip_store) {
return;
}
// Store the casted data into the output.
// Note that this store operation might write "out-of-bounds", but it is intentional:
// 1. The "out-of-bounds" here only crosses the boundary of the "local shard" (i.e., the region
// from start_offset to end_offset), not the boundary of the entire output memory. Therefore,
// this out-of-bounds write will not cause illegal memory access.
// 2. We assume that the subsequent all-gather operation happens in-place, so any parts that
// should not be updated here will be overwritten by the all-gather.
// This tricky approach allows us to avoid checking whether each output index falls within
// [start, end), resulting in a significant performance improvement.
Vec<OType, kNumOutputElemsPerBank> vec_output;
for (int i = 0; i < kRowsPerWarp; ++i) {
const int row_in_smem = threadIdx.x / kThreadsPerWarp * kRowsPerWarp + i;
const int col_in_smem = threadIdx.x % kThreadsPerWarp * kNumOutputElemsPerBank;
for (int j = 0; j < kNumOutputElemsPerBank; ++j) {
vec_output.data.elt[j] = smem[row_in_smem][col_in_smem + j];
}
const int row_in_output = tile_h * kTileDim64 + row_in_smem;
const int col_in_output = tile_w * kTileDim64 + col_in_smem;
const size_t idx_in_output = static_cast<size_t>(row_in_output) * w + col_in_output;
if (row_in_output < h) {
if constexpr (kWidthAligned) {
vec_output.store_to(output_minus_offset + idx_in_output);
} else {
int num = min(static_cast<size_t>(kNumOutputElemsPerBank),
static_cast<size_t>(col_in_output < w ? w - col_in_output : 0));
vec_output.store_to_elts(output_minus_offset, idx_in_output, num);
}
}
}
}
void fp8_block_scaling_compute_partial_amax(const Tensor inp, Tensor amax, size_t h, size_t w,
size_t amax_stride_h, size_t amax_stride_w,
size_t start_offset, size_t block_len,
cudaStream_t stream) {
NVTE_CHECK(block_len == 128 || block_len == 64,
"Currently only block_len = 128 or 64 is supported");
size_t len = inp.numel();
...@@ -187,26 +348,39 @@ void fp8_block_scaling_compute_partial_amax(const Tensor inp, Tensor amax, size_
assert(start_offset < h * w);
assert(start_offset + len <= h * w);
size_t blocks_x = (w + block_len - 1) / block_len;
size_t blocks_y = (h + block_len - 1) / block_len;
assert(blocks_x <= std::numeric_limits<unsigned int>::max());
assert(blocks_y <= std::numeric_limits<unsigned int>::max());
dim3 grid(blocks_x, blocks_y);
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
inp.dtype(), inp_dtype, while (true) {
if (128 == block_len) {
fp8_block_scaling_compute_partial_amax_kernel<inp_dtype>
<<<grid, kThreadsPerBlock, 0, stream>>>(
reinterpret_cast<const inp_dtype *>(inp.data.dptr),
reinterpret_cast<float *>(amax.data.dptr), amax_stride_h, amax_stride_w, h, w,
start_offset, len);
break;
}
if (64 == block_len) {
fp8_block_scaling_block_len64_compute_partial_amax_kernel<inp_dtype>
<<<grid, kThreadsPerBlock, 0, stream>>>(
reinterpret_cast<const inp_dtype *>(inp.data.dptr),
reinterpret_cast<float *>(amax.data.dptr), amax_stride_h, amax_stride_w, h, w,
start_offset, len);
break;
}
})
}
void fp8_block_scaling_partial_cast(const Tensor inp, Tensor out, const Tensor scale, size_t h,
size_t w, size_t scale_stride_h, size_t scale_stride_w,
size_t start_offset, size_t block_len, const DType out_dtype,
cudaStream_t stream) {
NVTE_CHECK(block_len == 128 || block_len == 64,
"Currently only block_len = 128 or 64 is supported");
size_t len = inp.numel();
...@@ -214,8 +388,8 @@ void fp8_block_scaling_partial_cast(const Tensor inp, Tensor out, const Tensor s
assert(start_offset < h * w);
assert(start_offset + len <= h * w);
size_t blocks_x = (w + block_len - 1) / block_len;
size_t blocks_y = (h + block_len - 1) / block_len;
assert(blocks_x <= std::numeric_limits<unsigned int>::max());
assert(blocks_y <= std::numeric_limits<unsigned int>::max());
dim3 grid(blocks_x, blocks_y);
...@@ -225,13 +399,27 @@ void fp8_block_scaling_partial_cast(const Tensor inp, Tensor out, const Tensor s
TRANSFORMER_ENGINE_TYPE_SWITCH_8BIT(
out_dtype, fp8_type,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
w % block_len == 0, kWidthAligned, while (true) {
if (128 == block_len) {
fp8_block_scaling_partial_cast_kernel<inp_dtype, fp8_type, kWidthAligned>
<<<grid, kThreadsPerBlock, 0, stream>>>(
reinterpret_cast<const inp_dtype *>(inp.data.dptr),
reinterpret_cast<fp8_type *>(out.data.dptr),
reinterpret_cast<const float *>(scale.data.dptr), scale_stride_h,
scale_stride_w, h, w, start_offset, len);
break;
}
if (64 == block_len) {
fp8_block_scaling_block_len64_partial_cast_kernel<inp_dtype, fp8_type,
kWidthAligned>
<<<grid, kThreadsPerBlock, 0, stream>>>(
reinterpret_cast<const inp_dtype *>(inp.data.dptr),
reinterpret_cast<fp8_type *>(out.data.dptr),
reinterpret_cast<const float *>(scale.data.dptr), scale_stride_h,
scale_stride_w, h, w, start_offset, len);
break;
}
})))
}
} // namespace fp8_block_scaling_recipe
...
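The two new kernels split the recipe into the usual two passes over each 64x64 tile: reduce a per-tile amax, then cast with the resulting scale, clamping int8 outputs to [-127, 127]. A compact NumPy reference of the same math (shapes assumed; the CUDA kernels additionally restrict work to the local shard window [start_offset, start_offset + len)):

    import numpy as np

    def tile_amax(x: np.ndarray, block_len: int = 64) -> np.ndarray:
        h, w = x.shape
        bh, bw = -(-h // block_len), -(-w // block_len)  # ceil-division
        amax = np.zeros((bh, bw), dtype=np.float32)
        for i in range(bh):
            for j in range(bw):
                tile = x[i * block_len:(i + 1) * block_len,
                         j * block_len:(j + 1) * block_len]
                amax[i, j] = np.abs(tile).max()
        return amax

    def partial_cast_int8(x: np.ndarray, scale: np.ndarray,
                          block_len: int = 64) -> np.ndarray:
        out = np.empty_like(x, dtype=np.int8)
        for i in range(scale.shape[0]):
            for j in range(scale.shape[1]):
                rows = slice(i * block_len, (i + 1) * block_len)
                cols = slice(j * block_len, (j + 1) * block_len)
                scaled = x[rows, cols] * scale[i, j]
                out[rows, cols] = np.clip(np.rint(scaled), -127, 127).astype(np.int8)
        return out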
...@@ -39,6 +39,7 @@ constexpr size_t WARP_TILE_DIM_Y = 64;
constexpr size_t THREAD_TILE_DIM_X = 16;
constexpr size_t THREAD_TILE_DIM_Y = 4;
#else
constexpr size_t BLOCK_TILE_DIM64 = 64;
constexpr size_t BLOCK_TILE_DIM = 128;
constexpr size_t WARP_TILE_DIM_X = 64;
constexpr size_t WARP_TILE_DIM_Y = 32;
...@@ -60,6 +61,11 @@ constexpr size_t NUM_WARPS_X_IN_BLOCK = BLOCK_TILE_DIM / WARP_TILE_DIM_X;
constexpr size_t NUM_WARPS_Y_IN_BLOCK = BLOCK_TILE_DIM / WARP_TILE_DIM_Y;
constexpr size_t NUM_WARPS_IN_BLOCK = NUM_WARPS_X_IN_BLOCK * NUM_WARPS_Y_IN_BLOCK;
constexpr size_t THREADS_PER_BLOCK64 = BLOCK_TILE_DIM64 * BLOCK_TILE_DIM64 / ELE_PER_THREAD;
constexpr size_t NUM_WARPS_X_IN_BLOCK64 = BLOCK_TILE_DIM64 / WARP_TILE_DIM_X;
constexpr size_t NUM_WARPS_Y_IN_BLOCK64 = BLOCK_TILE_DIM64 / WARP_TILE_DIM_Y;
constexpr size_t NUM_WARPS_IN_BLOCK64 = NUM_WARPS_X_IN_BLOCK64 * NUM_WARPS_Y_IN_BLOCK64;
constexpr size_t NUM_THREADS_X_IN_WARP = WARP_TILE_DIM_X / THREAD_TILE_DIM_X;
constexpr size_t NUM_THREADS_Y_IN_WARP = kThreadsPerWarp / NUM_THREADS_X_IN_WARP;
...@@ -188,11 +194,11 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
CType scale_data = block_tile_scale;
OType scaled_elt = 0;
if constexpr (std::is_same_v<OType, int8_t>) {
scaled_elt = static_cast<OType>(lroundf(
fmaxf(-127.0f,
fminf(127.0f, static_cast<CType>(thrd_tile_input[i].data.elt[j]) * scale_data))));
} else {
scaled_elt =
static_cast<OType>(static_cast<CType>(thrd_tile_input[i].data.elt[j]) * scale_data);
}
...@@ -439,11 +445,382 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose
// Step 3: Store cast output
CType scale_data = block_tile_scale;
OType scaled_elt = 0;
if constexpr (std::is_same_v<OType, int8_t>) {
scaled_elt = static_cast<OType>(lroundf(fmaxf(
-127.0f,
fminf(127.0f, static_cast<CType>(thrd_tile_input[i].data.elt[j]) * scale_data))));
} else {
scaled_elt =
static_cast<OType>(static_cast<CType>(thrd_tile_input[i].data.elt[j]) * scale_data);
}
tmp_output_c.data.elt[j] = scaled_elt;
// Step 4: do transpose within thread tile
if constexpr (kReturnTranspose) {
thrd_tile_out_trans[j].data.elt[i] = scaled_elt;
}
}
tmp_output_c.store_to_elts(output_c + thread_tile_start_idx + i * row_length, 0,
thread_tile_ncols);
}
if constexpr (kReturnTranspose) {
const size_t block_tile_t_start_idx =
tile_id_x * BLOCK_TILE_DIM * num_rows + tile_id_y * BLOCK_TILE_DIM;
const size_t warp_tile_t_start_idx =
block_tile_t_start_idx +
warp_id_in_block_x * THREAD_TILE_DIM_X * NUM_THREADS_X_IN_WARP * num_rows +
warp_id_in_block_y * THREAD_TILE_DIM_Y * NUM_THREADS_Y_IN_WARP;
const size_t thread_tile_t_start_idx = warp_tile_t_start_idx +
tid_in_warp_x * THREAD_TILE_DIM_X * num_rows +
tid_in_warp_y * THREAD_TILE_DIM_Y;
#pragma unroll
for (int i = 0; i < thread_tile_ncols; i++) {
thrd_tile_out_trans[i].store_to_elts(output_t + thread_tile_t_start_idx + i * num_rows, 0,
thread_tile_nrows);
}
}
}
}
template <bool kReturnTranspose, typename CType, typename IType, typename OType>
__global__ void __launch_bounds__(THREADS_PER_BLOCK64)
block_scaled_block_len64_cast_transpose_kernel(
const IType* const input, OType* const output_c, OType* const output_t,
CType* const tile_scales_inv_c, CType* const tile_scales_inv_t, const size_t row_length,
const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y,
const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon,
bool pow_2_scaling) {
using IVec = Vec<IType, THREAD_TILE_DIM_X>;
using OVecCast = Vec<OType, THREAD_TILE_DIM_X>;
using OVecTrans = Vec<OType, THREAD_TILE_DIM_Y>;
// shared mem for amax reduction in entire block, each warp produces one amax, there are
// NUM_WARPS_IN_BLOCK amax to reduce
__shared__ CType block_tile_amax_shared[NUM_WARPS_IN_BLOCK64];
IVec thrd_tile_input[THREAD_TILE_DIM_Y];
constexpr int THREAD_TILE_DIM_X_ = kReturnTranspose ? THREAD_TILE_DIM_X : 1;
OVecTrans thrd_tile_out_trans[THREAD_TILE_DIM_X_];
const int tid_in_warp = threadIdx.x % kThreadsPerWarp;
const int tid_in_warp_x = tid_in_warp % NUM_THREADS_X_IN_WARP;
const int tid_in_warp_y = tid_in_warp / NUM_THREADS_X_IN_WARP;
const int warp_id_in_block = threadIdx.x / kThreadsPerWarp;
const int warp_id_in_block_x = warp_id_in_block % NUM_WARPS_X_IN_BLOCK64;
const int warp_id_in_block_y = warp_id_in_block / NUM_WARPS_X_IN_BLOCK64;
// This is ONLY true if the input is a full tile
const int tile_id_x = blockIdx.x;
const int tile_id_y = blockIdx.y;
const size_t block_tile_start_idx =
tile_id_y * BLOCK_TILE_DIM64 * row_length + tile_id_x * BLOCK_TILE_DIM64;
const size_t warp_tile_start_idx =
block_tile_start_idx +
warp_id_in_block_y * THREAD_TILE_DIM_Y * NUM_THREADS_Y_IN_WARP * row_length +
warp_id_in_block_x * THREAD_TILE_DIM_X * NUM_THREADS_X_IN_WARP;
const size_t thread_tile_start_idx = warp_tile_start_idx +
tid_in_warp_y * THREAD_TILE_DIM_Y * row_length +
tid_in_warp_x * THREAD_TILE_DIM_X;
CType warp_tile_amax;
CType block_tile_amax;
CType block_tile_scale;
CType amax = 0;
// Step 1: Load a block tile of input data into thread tiles on registers
#pragma unroll
for (int i = 0; i < THREAD_TILE_DIM_Y; i++) {
thrd_tile_input[i].load_from(input + thread_tile_start_idx + i * row_length);
}
// Step 2: calculate block tile amax and scale
// Calculate thread_tile amax
for (int i = 0; i < THREAD_TILE_DIM_Y; i++) {
#pragma unroll
for (int j = 0; j < THREAD_TILE_DIM_X; j++) {
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(static_cast<CType>(thrd_tile_input[i].data.elt[j])));
}
}
// Reduce amax in the warp (32x32 tile)
warp_tile_amax = warp_reduce_max<kThreadsPerWarp>(amax);
// broadcast the amax to all threads in a warp from the lane 0
constexpr int lane_zero = 0;
#ifdef __HIP_PLATFORM_AMD__
warp_tile_amax = __shfl(warp_tile_amax, lane_zero, THREADS_PER_WARP);
#else
warp_tile_amax = __shfl_sync(0xFFFFFFFF, warp_tile_amax, lane_zero);
#endif
// reduce warp_tile_amax across multiple warps in a thread block using shared mem
if (tid_in_warp == 0) {
block_tile_amax_shared[warp_id_in_block_y * NUM_WARPS_X_IN_BLOCK64 + warp_id_in_block_x] =
warp_tile_amax;
}
__syncthreads();
// only 8 elements needs reduction, if using reduction tree, multiple _syncthreads will be needed,
// instead we just let thread 0 do the job
if (threadIdx.x == 0) {
CType blk_amax = block_tile_amax_shared[0];
#pragma unroll
for (int idx = 1; idx < NUM_WARPS_IN_BLOCK64; idx++) {
blk_amax = fmaxf(blk_amax, block_tile_amax_shared[idx]);
}
block_tile_amax_shared[0] = blk_amax;
}
__syncthreads();
block_tile_amax = block_tile_amax_shared[0];
block_tile_scale =
compute_scale_from_types<IType, OType>(block_tile_amax, epsilon, pow_2_scaling);
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
const CType scale_inv = 1.0f / block_tile_scale;
size_t row_idx = tile_id_y;
size_t col_idx = tile_id_x;
tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv;
if constexpr (kReturnTranspose) {
row_idx = tile_id_x;
col_idx = tile_id_y;
tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv;
}
}
// Step 3: Store cast output, Step 4: do transpose within thread tile
OVecCast tmp_output_c;
for (int i = 0; i < THREAD_TILE_DIM_Y; i++) {
#pragma unroll
for (int j = 0; j < THREAD_TILE_DIM_X; j++) {
// Step 3: Store cast output
CType scale_data = block_tile_scale;
OType scaled_elt = 0;
if constexpr (std::is_same_v<OType, int8_t>) {
scaled_elt = static_cast<OType>(lroundf(
fmaxf(-127.0f,
fminf(127.0f, static_cast<CType>(thrd_tile_input[i].data.elt[j]) * scale_data))));
} else {
scaled_elt =
static_cast<OType>(static_cast<CType>(thrd_tile_input[i].data.elt[j]) * scale_data);
}
tmp_output_c.data.elt[j] = scaled_elt;
// Step 4: do transpose within thread tile
if constexpr (kReturnTranspose) {
thrd_tile_out_trans[j].data.elt[i] = scaled_elt;
}
}
tmp_output_c.store_to(output_c + thread_tile_start_idx + i * row_length);
}
// Step 4: store transpose into shared memory
if constexpr (kReturnTranspose) {
// Step 4 Alternative (when TMA is not available, skip writing to shared memory)
const size_t block_tile_t_start_idx =
tile_id_x * BLOCK_TILE_DIM64 * num_rows + tile_id_y * BLOCK_TILE_DIM64;
const size_t warp_tile_t_start_idx =
block_tile_t_start_idx +
warp_id_in_block_x * THREAD_TILE_DIM_X * NUM_THREADS_X_IN_WARP * num_rows +
warp_id_in_block_y * THREAD_TILE_DIM_Y * NUM_THREADS_Y_IN_WARP;
const size_t thread_tile_t_start_idx = warp_tile_t_start_idx +
tid_in_warp_x * THREAD_TILE_DIM_X * num_rows +
tid_in_warp_y * THREAD_TILE_DIM_Y;
#pragma unroll
for (int i = 0; i < THREAD_TILE_DIM_X; i++) {
thrd_tile_out_trans[i].store_to(output_t + thread_tile_t_start_idx + i * num_rows);
}
}
}
template <bool kReturnTranspose, typename CType, typename IType, typename OType>
__global__ void __launch_bounds__(THREADS_PER_BLOCK64)
block_scaled_block_len64_cast_transpose_kernel_notaligned(
const IType* const input, OType* const output_c, OType* const output_t,
CType* const tile_scales_inv_c, CType* const tile_scales_inv_t, const size_t row_length,
const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y,
const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon,
bool pow_2_scaling) {
using IVec = Vec<IType, THREAD_TILE_DIM_X>;
using OVecCast = Vec<OType, THREAD_TILE_DIM_X>;
using OVecTrans = Vec<OType, THREAD_TILE_DIM_Y>;
// shared mem for amax reduction in entire block, each warp produces one amax, there are
// NUM_WARPS_IN_BLOCK amax to reduce
__shared__ CType block_tile_amax_shared[NUM_WARPS_IN_BLOCK64];
IVec thrd_tile_input[THREAD_TILE_DIM_Y];
constexpr int THREAD_TILE_DIM_X_ = kReturnTranspose ? THREAD_TILE_DIM_X : 1;
OVecTrans thrd_tile_out_trans[THREAD_TILE_DIM_X_];
const int tid_in_warp = threadIdx.x % kThreadsPerWarp;
const int tid_in_warp_x = tid_in_warp % NUM_THREADS_X_IN_WARP;
const int tid_in_warp_y = tid_in_warp / NUM_THREADS_X_IN_WARP;
const int warp_id_in_block = threadIdx.x / kThreadsPerWarp;
const int warp_id_in_block_x = warp_id_in_block % NUM_WARPS_X_IN_BLOCK64;
const int warp_id_in_block_y = warp_id_in_block / NUM_WARPS_X_IN_BLOCK64;
const int tile_id_x = blockIdx.x;
const int tile_id_y = blockIdx.y;
const size_t block_tile_start_row_idx = tile_id_y * BLOCK_TILE_DIM64;
const size_t block_tile_start_col_idx = tile_id_x * BLOCK_TILE_DIM64;
const size_t block_tile_start_idx =
block_tile_start_row_idx * row_length + block_tile_start_col_idx;
const size_t warp_tile_start_idx =
block_tile_start_idx +
warp_id_in_block_y * THREAD_TILE_DIM_Y * NUM_THREADS_Y_IN_WARP * row_length +
warp_id_in_block_x * THREAD_TILE_DIM_X * NUM_THREADS_X_IN_WARP;
const size_t thread_tile_start_idx = warp_tile_start_idx +
tid_in_warp_y * THREAD_TILE_DIM_Y * row_length +
tid_in_warp_x * THREAD_TILE_DIM_X;
// handle non-full tile
// check for three cases: full thread tile, nonfull thread tile, empty thread tile
// for empty thread tile, directly write zero to the transposed shared mem buffer
// for nonfull thread tile, fill zero to thread tile and act as if it's full
const size_t thread_tile_start_row_idx =
tile_id_y * BLOCK_TILE_DIM64 +
warp_id_in_block_y * THREAD_TILE_DIM_Y * NUM_THREADS_Y_IN_WARP +
tid_in_warp_y * THREAD_TILE_DIM_Y;
const size_t thread_tile_start_col_idx =
tile_id_x * BLOCK_TILE_DIM64 +
warp_id_in_block_x * THREAD_TILE_DIM_X * NUM_THREADS_X_IN_WARP +
tid_in_warp_x * THREAD_TILE_DIM_X;
const size_t thread_tile_end_row_idx = thread_tile_start_row_idx + THREAD_TILE_DIM_Y - 1;
const size_t thread_tile_end_col_idx = thread_tile_start_col_idx + THREAD_TILE_DIM_X - 1;
bool full_thrd_tile =
(thread_tile_end_row_idx < num_rows) && (thread_tile_end_col_idx < row_length);
bool empty_thrd_tile =
(thread_tile_start_row_idx >= num_rows) || (thread_tile_start_col_idx >= row_length);
bool nonfull_thrd_tile = (!full_thrd_tile) && (!empty_thrd_tile);
const size_t thread_tile_ncols =
MIN(THREAD_TILE_DIM_X,
(MIN(thread_tile_end_col_idx, row_length - 1) - thread_tile_start_col_idx + 1));
const size_t thread_tile_nrows =
MIN(THREAD_TILE_DIM_Y,
(MIN(thread_tile_end_row_idx, num_rows - 1) - thread_tile_start_row_idx + 1));
CType warp_tile_amax;
CType block_tile_amax;
CType block_tile_scale;
CType amax = 0;
if (!empty_thrd_tile) {
// Step 1: Load a block tile of input data into thread tiles on registers
// Edge case: nonfull thread tile case, will use the partial load function here
if (nonfull_thrd_tile) {
#pragma unroll
for (int i = 0; i < THREAD_TILE_DIM_Y; i++) {
if (i >= thread_tile_nrows) {
thrd_tile_input[i].clear();
} else {
thrd_tile_input[i].load_from_elts(input + thread_tile_start_idx + i * row_length, 0,
thread_tile_ncols);
}
}
} else {
#pragma unroll
for (int i = 0; i < THREAD_TILE_DIM_Y; i++) {
thrd_tile_input[i].load_from_elts(input + thread_tile_start_idx + i * row_length, 0,
THREAD_TILE_DIM_X);
}
}
// Step 2: calculate block tile amax and scale
// Calculate thread_tile amax
for (int i = 0; i < THREAD_TILE_DIM_Y; i++) {
#pragma unroll
for (int j = 0; j < THREAD_TILE_DIM_X; j++) {
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(static_cast<CType>(thrd_tile_input[i].data.elt[j])));
}
}
}
// Reduce amax in the warp (32x32 tile)
warp_tile_amax = warp_reduce_max<kThreadsPerWarp>(amax);
// broadcast the amax to all threads in a warp from the lane 0
constexpr int lane_zero = 0;
#ifdef __HIP_PLATFORM_AMD__
warp_tile_amax = __shfl(warp_tile_amax, lane_zero, THREADS_PER_WARP);
#else
warp_tile_amax = __shfl_sync(0xFFFFFFFF, warp_tile_amax, lane_zero);
#endif
// reduce warp_tile_amax across multiple warps in a thread block using shared mem
if (tid_in_warp == 0) {
block_tile_amax_shared[warp_id_in_block_y * NUM_WARPS_X_IN_BLOCK64 + warp_id_in_block_x] =
warp_tile_amax;
}
__syncthreads();
// only 8 elements needs reduction, if using reduction tree, multiple _syncthreads will be needed,
// instead we just let thread 0 do the job
if (threadIdx.x == 0) {
CType blk_amax = block_tile_amax_shared[0];
#pragma unroll
for (int idx = 1; idx < NUM_WARPS_IN_BLOCK64; idx++) {
blk_amax = fmaxf(blk_amax, block_tile_amax_shared[idx]);
}
block_tile_amax_shared[0] = blk_amax;
}
__syncthreads();
block_tile_amax = block_tile_amax_shared[0];
block_tile_scale =
compute_scale_from_types<IType, OType>(block_tile_amax, epsilon, pow_2_scaling);
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
const CType scale_inv = 1.0f / block_tile_scale;
size_t row_idx = tile_id_y;
size_t col_idx = tile_id_x;
tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv;
if constexpr (kReturnTranspose) {
row_idx = tile_id_x;
col_idx = tile_id_y;
tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv;
}
}
// Step 3: Store cast output, Step 4: do transpose within thread tile
// Edge case: in the non-full tile case, there are three subcases
// for full thread tile, it's the same thing here
// for nonfull thread tile, pay attention when saving tmp_output_c to global
// memory, cannot vec store_to, but need to elt store to for empty tile,
// it should not enter this step, skip to Step 4
// set thrd_tile_out_trans to all zero
if constexpr (kReturnTranspose) {
#pragma unroll
for (int j = 0; j < THREAD_TILE_DIM_X; j++) {
thrd_tile_out_trans[j].clear();
}
}
if (!empty_thrd_tile) {
OVecCast tmp_output_c;
for (int i = 0; i < THREAD_TILE_DIM_Y; i++) {
if (i >= thread_tile_nrows) {
continue;
}
#pragma unroll
for (int j = 0; j < THREAD_TILE_DIM_X; j++) {
// Step 3: Store cast output
CType scale_data = block_tile_scale;
OType scaled_elt = 0;
if constexpr (std::is_same_v<OType, int8_t>) {
scaled_elt = static_cast<OType>(lroundf(fmaxf(
-127.0f,
fminf(127.0f, static_cast<CType>(thrd_tile_input[i].data.elt[j]) * scale_data))));
} else {
scaled_elt =
static_cast<OType>(static_cast<CType>(thrd_tile_input[i].data.elt[j]) * scale_data);
}
...@@ -459,7 +836,7 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose
if constexpr (kReturnTranspose) {
const size_t block_tile_t_start_idx =
tile_id_x * BLOCK_TILE_DIM64 * num_rows + tile_id_y * BLOCK_TILE_DIM64;
const size_t warp_tile_t_start_idx =
block_tile_t_start_idx +
warp_id_in_block_x * THREAD_TILE_DIM_X * NUM_THREADS_X_IN_WARP * num_rows +
...@@ -540,8 +917,9 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
scale_t_stride_y = scale_inv_t.shape[1];
}
const size_t block_len = blockwise_fp8_block_len();
const size_t num_blocks_x = DIVUP(row_length, block_len);
const size_t num_blocks_y = DIVUP(num_rows, block_len);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.dtype, InputType,
...@@ -553,8 +931,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
return_transpose, kReturnTranspose,
dim3 grid(num_blocks_x, num_blocks_y, 1);
const bool full_tile = row_length % block_len == 0 && num_rows % block_len == 0;
if (full_tile) {
#ifndef __HIP_PLATFORM_AMD__
...@@ -573,27 +950,64 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon,
tensor_map_output_trans, pow_2_scale);
#else
while (true) {
if (128 == block_len) {
block_scaled_cast_transpose_kernel<kReturnTranspose, float, InputType,
OutputType>
<<<grid, THREADS_PER_BLOCK, 0, stream>>>(
reinterpret_cast<const InputType*>(input.dptr),
reinterpret_cast<OutputType*>(output.dptr),
reinterpret_cast<OutputType*>(output_t.dptr),
reinterpret_cast<float*>(scale_inv.dptr),
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y,
epsilon, pow_2_scale);
break;
}
if (64 == block_len) {
block_scaled_block_len64_cast_transpose_kernel<kReturnTranspose, float,
InputType, OutputType>
<<<grid, THREADS_PER_BLOCK64, 0, stream>>>(
reinterpret_cast<const InputType*>(input.dptr),
reinterpret_cast<OutputType*>(output.dptr),
reinterpret_cast<OutputType*>(output_t.dptr),
reinterpret_cast<float*>(scale_inv.dptr),
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y,
epsilon, pow_2_scale);
break;
}
}
#endif
          } else {
            while (true) {
              if (128 == block_len) {
                block_scaled_cast_transpose_kernel_notaligned<kReturnTranspose, float,
                                                              InputType, OutputType>
                    <<<grid, THREADS_PER_BLOCK, 0, stream>>>(
                        reinterpret_cast<const InputType*>(input.dptr),
                        reinterpret_cast<OutputType*>(output.dptr),
                        reinterpret_cast<OutputType*>(output_t.dptr),
                        reinterpret_cast<float*>(scale_inv.dptr),
                        reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
                        scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y,
                        epsilon, pow_2_scale);
break;
}
if (64 == block_len) {
block_scaled_block_len64_cast_transpose_kernel_notaligned<
kReturnTranspose, float, InputType, OutputType>
<<<grid, THREADS_PER_BLOCK64, 0, stream>>>(
reinterpret_cast<const InputType*>(input.dptr),
reinterpret_cast<OutputType*>(output.dptr),
reinterpret_cast<OutputType*>(output_t.dptr),
reinterpret_cast<float*>(scale_inv.dptr),
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y,
epsilon, pow_2_scale);
break;
}
}
          } // full-tile
      ) // return_transpose
  ) // OutputType
...
@@ -131,6 +131,7 @@ Step 3: Transpose, cast and store to output_t
constexpr size_t kThreadsPerWarp = 32;
// Hyperparameters for performance tuning
constexpr int kTileDim64 = 64;
constexpr int kTileDim = 128;  // Fixed to 128 because we are using 1x128 and 128x1 quantization
constexpr int kNVecIn = 8;   // The number of elements each LDG touches
constexpr int kNVecOut = 16;  // The number of elements each STG touches
@@ -148,6 +149,15 @@ constexpr int kNumThreadsStore = kTileDim / kNVecOut;
static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kThreadsPerWarp");
static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp");
constexpr int kSMemRow64 = kTileDim64;
constexpr int kSMemCol64 = (kTileDim64 / kNVecSMem) + 1;
constexpr int kSMemSize64 = kSMemRow64 * kSMemCol64 * kNVecSMem;
constexpr int kNumThreadsLoad64 = kTileDim64 / kNVecIn;
constexpr int kNumThreadsStore64 = kTileDim64 / kNVecOut;
static_assert(kNumThreadsLoad64 <= kThreadsPerWarp, "kNumThreadsLoad64 must be <= kThreadsPerWarp");
static_assert(kNumThreadsStore64 <= kThreadsPerWarp,
"kNumThreadsStore64 must be <= kThreadsPerWarp");
template <bool kAligned, typename CType, typename IType, typename OType>
__global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel(
    const IType* const input, OType* const output_c, OType* const output_t,
@@ -423,6 +433,290 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
  }
}
template <bool kAligned, typename CType, typename IType, typename OType>
__global__ void __launch_bounds__(kThreadsPerBlock)
block_scaled_block_len64_1d_cast_transpose_kernel(
const IType* const input, OType* const output_c, OType* const output_t,
CType* const tile_scales_inv_c, CType* const tile_scales_inv_t, const size_t row_length,
const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y,
const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon,
FP8BlockwiseRowwiseOption rowwise_option, FP8BlockwiseColumnwiseOption columnwise_option,
const bool pow_2_scaling) {
bool return_rowwise = rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE;
bool return_columnwise_transpose =
columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE;
using SMemVec = Vec<IType, kNVecSMem>;
using OVec = Vec<OType, kNVecOut>;
union IVec {
Vec<IType, kNVecIn> input_type;
Vec<SMemVec, kNVecIn / kNVecSMem> smem_type;
};
extern __shared__ char smem_base[];
SMemVec* smem = reinterpret_cast<SMemVec*>(&smem_base[0]);
// Step 1: Load input to shared memory
{
constexpr int r_stride =
kThreadsPerBlock / kNumThreadsLoad64; // stride in rows of shared memory
constexpr int num_iterations = kTileDim64 / r_stride;
const int c_s =
(threadIdx.x % kNumThreadsLoad64) * (kNVecIn / kNVecSMem); // Column in shared memory
int r_s = threadIdx.x / kNumThreadsLoad64; // Row in shared memory
const size_t c_g =
static_cast<size_t>(blockIdx.x) * kTileDim64 + c_s * kNVecSMem; // Column in global memory
size_t r_g = static_cast<size_t>(blockIdx.y) * kTileDim64 + r_s; // Row in global memory
const size_t stride_g = static_cast<size_t>(r_stride) * row_length; // Stride in global memory
const size_t num_ele = c_g < row_length ? min(static_cast<size_t>(kNVecIn), row_length - c_g)
: 0; // For not aligned case
const IType* input_g = &input[r_g * row_length + c_g]; // Input address in global memory
#pragma unroll
for (int iter = 0; iter < num_iterations; ++iter) {
IVec input_vec;
// Step 1.1: Load from global memory (input) to registers
if constexpr (kAligned) {
input_vec.input_type.load_from(input_g);
} else {
if (r_g < num_rows) {
input_vec.input_type.load_from_elts(input_g, 0, num_ele);
} else {
input_vec.input_type.clear();
}
}
// Step 1.2: Write to shared memory
#pragma unroll
for (int i = 0; i < kNVecIn / kNVecSMem; ++i) {
int c = c_s + i;
int r = r_s;
smem[r * kSMemCol64 + c] = input_vec.smem_type.data.elt[i];
}
// Step 1.3: Update input address, row index of shared memory, (and row index of global memory for not aligned case)
input_g += stride_g;
r_s += r_stride;
if constexpr (!kAligned) {
r_g += r_stride;
}
}
}
__syncthreads();
// Step 2: Cast and store to output_c
if (return_rowwise) {
constexpr int r_stride =
kThreadsPerBlock / kNumThreadsStore64; // stride in rows of shared memory
constexpr int num_iterations = kTileDim64 / r_stride;
const int c_s =
(threadIdx.x % kNumThreadsStore64) * (kNVecOut / kNVecSMem); // Column in shared memory
int r_s = threadIdx.x / kNumThreadsStore64; // Row in shared memory
const size_t c_g =
static_cast<size_t>(blockIdx.x) * kTileDim64 + c_s * kNVecSMem; // Column in global memory
size_t r_g = static_cast<size_t>(blockIdx.y) * kTileDim64 + r_s; // Row in global memory
const size_t stride_g = static_cast<size_t>(r_stride) * row_length; // Stride in global memory
const size_t num_ele = c_g < row_length ? min(static_cast<size_t>(kNVecOut), row_length - c_g)
: 0; // For not aligned case
OType* output_g = &output_c[r_g * row_length + c_g]; // Output address in global memory
    // Each group of kNumThreadsStore64 threads in a warp processes one row; we need the
    // lane id of the first thread in each group, which leads the amax reduction.
const unsigned src_lane =
(threadIdx.x % kThreadsPerWarp) / kNumThreadsStore64 * kNumThreadsStore64;
// This mask represents which threads should do the reduction together.
const unsigned mask = ((1 << kNumThreadsStore64) - 1) << src_lane;
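    // Example: with kNumThreadsStore64 = kTileDim64 / kNVecOut = 4, lanes {0..3}, {4..7}, ...
    // each handle one row, and mask covers exactly the four lanes of this thread's group.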
const bool is_src_lane = (threadIdx.x % kNumThreadsStore64) == 0;
#pragma unroll
for (int iter = 0; iter < num_iterations; ++iter) {
SMemVec smem_vec[kNVecOut / kNVecSMem];
// Step 2.1: Load from shared memory to registers
#pragma unroll
for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
int c = c_s + i;
int r = r_s;
smem_vec[i] = smem[r * kSMemCol64 + c];
}
// Step 2.2: Compute local amax
CType amax = 0;
#pragma unroll
for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
#pragma unroll
for (int j = 0; j < kNVecSMem; ++j) {
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[j]));
}
}
// Step 2.3: Reduce amax
#pragma unroll
for (int delta = kNumThreadsStore64 / 2; delta > 0; delta /= 2) {
#ifdef __HIP_PLATFORM_AMD__
const float other_amax =
__shfl_down_sync((unsigned long long)(mask), amax, delta, kThreadsPerWarp);
#else
const float other_amax = __shfl_down_sync(mask, amax, delta);
#endif
__builtin_assume(amax >= 0);
__builtin_assume(other_amax >= 0);
amax = fmaxf(amax, other_amax);
}
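      // After the shuffle-down steps only src_lane holds the fully reduced amax;
      // broadcast it to the rest of the group before computing the scale.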
#ifdef __HIP_PLATFORM_AMD__
amax = __shfl_sync((unsigned long long)(mask), amax, src_lane, kThreadsPerWarp);
#else
amax = __shfl_sync(mask, amax, src_lane);
#endif
CType scale;
// Step 2.4: Compute scale
scale = compute_scale_from_types<IType, OType>(amax, epsilon, pow_2_scaling);
// Step 2.5: Write scale_inv
bool write_scale_inv = is_src_lane;
if constexpr (!kAligned) {
write_scale_inv &= (r_g < num_rows);
}
if (write_scale_inv) {
CType scale_inv = 1.0 / scale;
size_t row_idx = static_cast<size_t>(blockIdx.y) * kTileDim64 + r_s;
size_t col_idx = static_cast<size_t>(blockIdx.x);
tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv;
}
// Step 2.6: Quantize
OVec output_vec;
#pragma unroll
for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
#pragma unroll
for (int j = 0; j < kNVecSMem; ++j) {
if constexpr (std::is_same_v<OType, int8_t>) {
output_vec.data.elt[i * kNVecSMem + j] = static_cast<OType>(lroundf(fmaxf(
-127.0f, fminf(127.0f, static_cast<CType>(smem_vec[i].data.elt[j]) * scale))));
} else {
output_vec.data.elt[i * kNVecSMem + j] =
static_cast<OType>(static_cast<CType>(smem_vec[i].data.elt[j]) * scale);
}
}
}
// Step 2.7: Store output_c
if constexpr (kAligned) {
output_vec.store_to(output_g);
} else {
if (r_g < num_rows) {
output_vec.store_to_elts(output_g, 0, num_ele);
}
}
// Step 2.8: Update output address, row index of shared memory (and row index of global memory for not aligned case)
output_g += stride_g;
r_s += r_stride;
if constexpr (!kAligned) {
r_g += r_stride;
}
}
}
// Step 3: Transpose, cast and store to output_t
if (return_columnwise_transpose) {
constexpr int c_stride =
kThreadsPerBlock / kNumThreadsStore64; // Stride in columns of shared memory
constexpr int total_smem_cols = kTileDim64 / kNVecSMem;
constexpr int num_iterations = (total_smem_cols + c_stride - 1) / c_stride;
const int r_s = (threadIdx.x % kNumThreadsStore64) * kNVecOut; // Row in shared memory
int c_s = threadIdx.x / kNumThreadsStore64; // Column in shared memory
size_t r_g =
static_cast<size_t>(blockIdx.x) * kTileDim64 + c_s * kNVecSMem; // Row in global memory
const size_t c_g =
static_cast<size_t>(blockIdx.y) * kTileDim64 + r_s; // Column in global memory
const size_t stride_g =
static_cast<size_t>(c_stride) * kNVecSMem * num_rows; // Stride in global memory
const size_t num_ele = c_g < num_rows ? min(static_cast<size_t>(kNVecOut), num_rows - c_g)
: 0; // For not aligned case
OType* output_g = &output_t[r_g * num_rows + c_g]; // Output address in global memory
    // Each group of kNumThreadsStore64 threads in a warp processes one column of the tile;
    // we need the lane id of the first thread in each group, which leads the amax reduction.
const unsigned src_lane =
(threadIdx.x % kThreadsPerWarp) / kNumThreadsStore64 * kNumThreadsStore64;
// This mask represents which threads should do the reduction together.
const unsigned mask = ((1 << kNumThreadsStore64) - 1) << src_lane;
const bool is_src_lane = (threadIdx.x % kNumThreadsStore64) == 0;
#pragma unroll
for (int iter = 0; iter < num_iterations; ++iter) {
if (c_s < total_smem_cols) {
SMemVec smem_vec[kNVecOut];
// Step 3.1: Load from shared memory to registers
#pragma unroll
for (int i = 0; i < kNVecOut; ++i) {
int r = r_s + i;
int c = c_s;
smem_vec[i] = smem[r * kSMemCol64 + c];
}
#pragma unroll
for (int smem_idx = 0; smem_idx < kNVecSMem; ++smem_idx) {
// Step 3.2: Compute local amax
CType amax = 0;
#pragma unroll
for (int i = 0; i < kNVecOut; ++i) {
amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[smem_idx]));
}
// Step 3.3: Reduce amax
#pragma unroll
for (int delta = kNumThreadsStore64 / 2; delta > 0; delta /= 2) {
#ifdef __HIP_PLATFORM_AMD__
const float other_amax =
__shfl_down_sync((unsigned long long)(mask), amax, delta, kThreadsPerWarp);
#else
const float other_amax = __shfl_down_sync(mask, amax, delta);
#endif
__builtin_assume(amax >= 0);
__builtin_assume(other_amax >= 0);
amax = fmaxf(amax, other_amax);
}
#ifdef __HIP_PLATFORM_AMD__
amax = __shfl_sync((unsigned long long)(mask), amax, src_lane, kThreadsPerWarp);
#else
amax = __shfl_sync(mask, amax, src_lane);
#endif
// Step 3.4: Compute scale
CType scale;
scale = compute_scale_from_types<IType, OType>(amax, epsilon, pow_2_scaling);
// Step 3.5: Write scale_inv_t
bool write_scale_inv = is_src_lane;
if constexpr (!kAligned) {
write_scale_inv &= (r_g + smem_idx < row_length);
}
if (write_scale_inv) {
CType scale_inv = 1.0 / scale;
size_t row_idx =
static_cast<size_t>(blockIdx.x) * kTileDim64 + c_s * kNVecSMem + smem_idx;
size_t col_idx = static_cast<size_t>(blockIdx.y);
tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv;
}
// Step 3.6: Quantize
OVec output_vec;
#pragma unroll
for (int i = 0; i < kNVecOut; ++i) {
if constexpr (std::is_same_v<OType, int8_t>) {
output_vec.data.elt[i] = static_cast<OType>(lroundf(fmaxf(
-127.0f,
fminf(127.0f, static_cast<CType>(smem_vec[i].data.elt[smem_idx]) * scale))));
} else {
output_vec.data.elt[i] =
static_cast<OType>(static_cast<CType>(smem_vec[i].data.elt[smem_idx]) * scale);
}
}
// Step 3.7: Store output_t
if constexpr (kAligned) {
output_vec.store_to(output_g + smem_idx * num_rows);
} else {
if (r_g + smem_idx < row_length) {
output_vec.store_to_elts(output_g + smem_idx * num_rows, 0, num_ele);
}
}
}
}
// Step 3.8: Update output address, column index of shared memory (and row index of global memory for not aligned case)
output_g += stride_g;
c_s += c_stride;
if constexpr (!kAligned) {
r_g += c_stride * kNVecSMem;
}
}
}
}
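For reference, the rowwise (1x64) math this kernel implements per block, as a plain host-side
sketch (illustrative only: it assumes float in/out and a simple amax-based scale, and omits the
epsilon handling, power-of-2 scales, int8 clamping, and the columnwise-transpose path above):

#include <algorithm>
#include <cmath>
#include <cstddef>

// Quantize one contiguous 1x64 block: find the absolute maximum, derive the scale that maps
// it to the FP8 representable maximum, record the reciprocal, and scale the elements
// (the cast to FP8 happens on store in the real kernel).
void ref_quantize_1x64(const float* in, float* out, float* scale_inv,
                       float fp8_max /* e.g. 448.0f for E4M3 */) {
  constexpr std::size_t kBlockLen = 64;
  float amax = 0.0f;
  for (std::size_t i = 0; i < kBlockLen; ++i) amax = std::max(amax, std::fabs(in[i]));
  const float scale = amax > 0.0f ? fp8_max / amax : 1.0f;
  *scale_inv = 1.0f / scale;
  for (std::size_t i = 0; i < kBlockLen; ++i) out[i] = in[i] * scale;
}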
#ifdef __HIP_PLATFORM_AMD__
constexpr int kFP32SMemCol = kTileDim / kNVecSMem;
constexpr int kFP32SMemSize = kSMemRow * kFP32SMemCol * kNVecSMem;
@@ -771,8 +1065,9 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
    scale_t_stride_y = 1;
  }
  const size_t block_len = blockwise_fp8_block_len();
  const size_t num_blocks_x = DIVUP(row_length, (size_t)block_len);
  const size_t num_blocks_y = DIVUP(num_rows, (size_t)block_len);
  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      input.dtype, InputType,
@@ -782,35 +1077,38 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
          dim3 grid(num_blocks_x, num_blocks_y, 1);
          const bool full_tile = row_length % block_len == 0 && num_rows % block_len == 0;
          TRANSFORMER_ENGINE_SWITCH_CONDITION(
              full_tile, kAligned,
#ifdef __HIP_PLATFORM_AMD__
while (true) {
if (128 == block_len) {
                if constexpr (std::is_same_v<InputType, float>) {
                  size_t smem_bytes = kFP32SMemSize * sizeof(InputType);
                  if (smem_bytes >= 48 * 1024) {
                    cudaError_t err = cudaFuncSetAttribute(
                        (const void*)&block_scaled_1d_cast_transpose_kernel_fp32<
                            kAligned, float, InputType, OutputType>,
                        cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
                  }
                  block_scaled_1d_cast_transpose_kernel_fp32<kAligned, float, InputType,
                                                             OutputType>
                      <<<grid, kThreadsPerBlock, smem_bytes, stream>>>(
                          reinterpret_cast<const InputType*>(input.dptr),
                          reinterpret_cast<OutputType*>(output.dptr),
                          reinterpret_cast<OutputType*>(output_t.dptr),
                          reinterpret_cast<float*>(scale_inv.dptr),
                          reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
                          scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y,
                          epsilon, rowwise_option, columnwise_option, pow2_scale);
                } else {
                  size_t smem_bytes = kSMemSize * sizeof(InputType);
                  if (smem_bytes >= 48 * 1024) {
                    cudaError_t err = cudaFuncSetAttribute(
                        (const void*)&block_scaled_1d_cast_transpose_kernel<
                            kAligned, float, InputType, OutputType>,
                        cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
                  }
                  block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>
@@ -820,8 +1118,31 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
                          reinterpret_cast<OutputType*>(output_t.dptr),
                          reinterpret_cast<float*>(scale_inv.dptr),
                          reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
                          scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y,
                          epsilon, rowwise_option, columnwise_option, pow2_scale);
}
break;
}
if (64 == block_len) {
size_t smem_bytes = kSMemSize64 * sizeof(InputType);
if (smem_bytes >= 48 * 1024) {
cudaError_t err = cudaFuncSetAttribute(
(const void*)&block_scaled_block_len64_1d_cast_transpose_kernel<
kAligned, float, InputType, OutputType>,
cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
}
block_scaled_block_len64_1d_cast_transpose_kernel<kAligned, float, InputType,
OutputType>
<<<grid, kThreadsPerBlock, smem_bytes, stream>>>(
reinterpret_cast<const InputType*>(input.dptr),
reinterpret_cast<OutputType*>(output.dptr),
reinterpret_cast<OutputType*>(output_t.dptr),
reinterpret_cast<float*>(scale_inv.dptr),
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y,
epsilon, rowwise_option, columnwise_option, pow2_scale);
break;
}
            }
#else
          size_t smem_bytes = kSMemSize * sizeof(InputType);
...
@@ -15,6 +15,7 @@ from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt_wgrad import w8a8_
from ..tensor.quantized_tensor import Quantizer
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ...debug.pytorch.debug_quantization import DebugQuantizer
from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))
__all__ = [
@@ -76,7 +77,7 @@ def general_gemm(
            ref_scales_w = A._rowwise_scale_inv
            y, _ = w8a8_block_int8_matmul(
                qx_data, qw_data, ref_scales_x, ref_scales_w,
                [blockwise_fp8_block_len, blockwise_fp8_block_len],
                output_dtype=out_dtype
            )
            return y, None, None, None
@@ -92,7 +93,7 @@ def general_gemm(
            ref_scales_w = A._columnwise_scale_inv
            y, _ = w8a8_block_int8_matmul(
                qdout_data, qw_data, ref_scales_dout, ref_scales_w,
                [blockwise_fp8_block_len, blockwise_fp8_block_len],
                output_dtype=out_dtype
            )
            return y, None, None, None
@@ -108,7 +109,7 @@ def general_gemm(
            ref_scales_x = A._columnwise_scale_inv
            out, _ = w8a8_block_int8_matmul_wgrad(
                qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate,
                [blockwise_fp8_block_len, blockwise_fp8_block_len],
                output_dtype=out_dtype
            )
            return out, None, None, None
@@ -243,7 +244,7 @@ def general_grouped_gemm(
        seq_len = sum(m_splits) // num_gemms
        out[0] = w8a8_block_int8_matmul_batched(
            qx_data, qw_data, ref_scales_x, ref_scales_w,
            out[0].view(num_gemms, seq_len, out[0].size(-1)),
            [blockwise_fp8_block_len, blockwise_fp8_block_len],
            output_dtype=out_dtype
        )
        return out, bias, gelu_input
@@ -262,7 +263,7 @@ def general_grouped_gemm(
        seq_len = sum(m_splits) // num_gemms
        out[0] = w8a8_block_int8_matmul_batched(
            qdout_data, qw_data, ref_scales_dout, ref_scales_w,
            out[0].view(num_gemms, seq_len, out[0].size(-1)),
            [blockwise_fp8_block_len, blockwise_fp8_block_len],
            output_dtype=out_dtype
        )
        return out, bias, gelu_input
@@ -278,7 +279,7 @@ def general_grouped_gemm(
        ref_scales_x = [a._columnwise_scale_inv for a in A]
        out = w8a8_block_int8_matmul_wgrad_batched_native(
            qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate,
            [blockwise_fp8_block_len, blockwise_fp8_block_len],
            output_dtype=out_dtype
        )
        return out, bias, gelu_input
...
@@ -48,6 +48,8 @@
#include <cassert>
#include <cstring>
#include <iostream>
#include <cstdlib>
#include <memory>
#include <sstream>
#include <string>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <vector>
@@ -60,6 +62,18 @@ namespace transformer_engine::pytorch {
// in python we have: dist_group_type = torch.distributed.ProcessGroup
using dist_group_type = c10d::ProcessGroup;
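// Block length for blockwise FP8 quantization. Read from NVTE_BLOCKWISE_FP8_BLOCK_LEN;
// defaults to 128 when the variable is unset or empty.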
inline int blockwise_fp8_block_len() {
const char *env = std::getenv("NVTE_BLOCKWISE_FP8_BLOCK_LEN");
if (env == nullptr || env[0] == '\0') {
return 128;
}
  int value = 0;
  std::istringstream iss(env);
  iss >> value;
  NVTE_CHECK(iss, "Invalid NVTE_BLOCKWISE_FP8_BLOCK_LEN value: ", env);
  NVTE_CHECK(value == 64 || value == 128,
             "NVTE_BLOCKWISE_FP8_BLOCK_LEN must be 64 or 128, got ", value);
  return value;
}
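A standalone, runnable sketch of the same lookup (illustrative; the library raises via
NVTE_CHECK where this returns -1). Because call sites may cache the result in
namespace-scope globals, set the variable in the launching environment, e.g.
NVTE_BLOCKWISE_FP8_BLOCK_LEN=64, rather than from inside the process:

#include <cstdlib>
#include <iostream>
#include <sstream>

int sketch_block_len() {
  const char* env = std::getenv("NVTE_BLOCKWISE_FP8_BLOCK_LEN");
  if (env == nullptr || env[0] == '\0') return 128;  // default block length
  int value = 0;
  std::istringstream iss(env);
  iss >> value;
  if (!iss || (value != 64 && value != 128)) return -1;  // invalid setting
  return value;
}

int main() {
  std::cout << "block_len = " << sketch_block_len() << "\n";  // 128 unless overridden
  return 0;
}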
// Each tensor here is shape (N, ) holding all scaling
// data for a single FP8 block, e.g. LayerNormLinear
class FP8TensorMeta {
...
@@ -10,7 +10,7 @@ namespace transformer_engine::pytorch {
void fp8_block_scaling_compute_partial_amax(const at::Tensor &tensor, at::Tensor amax, size_t h,
                                            size_t w, size_t start_offset, size_t block_len) {
  TORCH_CHECK(block_len == 128 || block_len == 64,
              "Currently only block_len = 128 or 64 is supported");
  TORCH_CHECK(amax.dim() == 2, "amax must be a 2D tensor");
  TORCH_CHECK(amax.scalar_type() == at::ScalarType::Float, "amax must be a float tensor");
  TORCH_CHECK(tensor.scalar_type() == at::ScalarType::Float ||
@@ -28,7 +28,7 @@ void fp8_block_scaling_compute_partial_amax(const at::Tensor &tensor, at::Tensor
void fp8_block_scaling_partial_cast(const at::Tensor &inp, at::Tensor out, const at::Tensor &scale,
                                    size_t h, size_t w, size_t start_offset, size_t block_len,
                                    const transformer_engine::DType out_dtype) {
  TORCH_CHECK(block_len == 128 || block_len == 64,
              "Currently only block_len = 128 or 64 is supported");
  TORCH_CHECK(scale.dim() == 2, "scale must be a 2D tensor");
  TORCH_CHECK(scale.scalar_type() == at::ScalarType::Float, "scale must be a float tensor");
  TORCH_CHECK(
...
@@ -297,7 +297,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
  size_t k_dim = torch_shape.size() == 0 ? 1u : torch_shape.back();
  size_t m_dim = numel / k_dim;
  size_t kBlockLen = static_cast<size_t>(blockwise_fp8_block_len());
  if (rowwise_usage) {
    if (rowwise_data.has_value()) {
...
@@ -1018,7 +1018,7 @@ def _all_gather_fp8_blockwise(
    # Check that quantizer is valid
    if quantizer is not None and not isinstance(quantizer, Float8BlockQuantizer):
        raise ValueError(f"Got non-FP8 blockwise quantizer ({quantizer.__class__.__name__})")
    if not (quantizer.block_scaling_dim == 1 and quantizer.block_len in (64, 128)):
        raise NotImplementedError("Only 1D blockwise quantization is supported for allgather")
    # Output tensor dims
...
@@ -28,6 +28,7 @@ from .utils import get_device_compute_capability
from .jit import jit_fuser
from torch.utils.cpp_extension import IS_HIP_EXTENSION
int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))
blockwise_fp8_block_len = int(os.getenv("NVTE_BLOCKWISE_FP8_BLOCK_LEN", "128"))
__all__ = ["fp8_autocast", "fp8_model_init"]
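Note: blockwise_fp8_block_len above is evaluated once at import time, so
NVTE_BLOCKWISE_FP8_BLOCK_LEN must be exported before transformer_engine.pytorch is imported;
setting it afterwards has no effect on this constant.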
...
@@ -11,6 +11,7 @@ import torch
import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
from ..quantized_tensor import QuantizedTensorBase
@@ -125,7 +126,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
        return torch.permute(columnwise_dq, tuple(permute_dims)).contiguous()
    def _dequantize_vectorwise(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
        block_len = blockwise_fp8_block_len
        q_M, q_K = 1, 1
        if self._rowwise_data is not None:
@@ -178,7 +179,7 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
        """
        Construct plain PyTorch tensor from Float8BlockwiseQTensor
        """
        block_len = blockwise_fp8_block_len
        if not self._is_2D_scaled:
            return self._dequantize_vectorwise(dtype=dtype)
...
@@ -14,6 +14,7 @@ from transformer_engine_torch import DType as TE_DType
from ._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
from ..utils import devices_match, round_up_to_nearest_multiple
from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
aten = torch.ops.aten
@@ -46,7 +47,7 @@ class Float8BlockQuantizer(Quantizer):
    ) -> None:
        super().__init__(rowwise=rowwise, columnwise=columnwise)
        self.dtype = tex.DType.kInt8 if int8_simulation_fp8 else fp8_dtype
        self.block_len = blockwise_fp8_block_len
        self.force_pow_2_scales = force_pow_2_scales
        self.amax_epsilon = amax_epsilon
        self.block_scaling_dim = block_scaling_dim
...