Unverified commit a8f0fe03, authored by kwyss-nvidia, committed by GitHub

Blockwise scaling linear quantization recipe (#1559)



* Add GEMM logic for blockwise quantized tensors.

GEMM test cases included in PyTorch integration.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update NVTE_BLOCK_SCALING for GEMM.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Gate feature on CUDA 12.9
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Fix GEMM typo.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Remove unnecessary type converter change.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Reflect epilogue availability and test supported epilogues.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* GEMM simplifications from recipe branch.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Format Python code.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update GEMM DGelu tests to match support depending on output dtype.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Force pow2Scales in GEMM.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Add GEMM test to PyTorch test suite.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Add copyright to GEMM test.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update import for GEMM test.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Add license.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update test GEMM supported predicate.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Use sgemm-like interfaces and naming.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Rewrite GEMM comment.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* MR feedback.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Recipe setup for Linear modules.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Use 12.9 feature test.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Run against tensor dumps from internal library.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update FIXME to TODO with linked issue.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update full recompute feature to save recipe.

The recompute context uses the same recipe and FP8 settings as the original forward pass.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* MR feedback. Avoid reusing quantizer objects.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update logic in module.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Format Python code.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update for PP bug.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update test numerics.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update force_power_of_2 scales in the recipe.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update usage method to satisfy upstream changes.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Fix subchannel recipe in distributed test with BF16 gather.
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* Edit and clean up BF16 gather code.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update test import.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Support columnwise-only mode in the 1D quantize kernel.
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* Format and move enum.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Skip alloc.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Try async BF16 gather.
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* Format Python code.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Add documentation and type annotations.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Fix PyTorch lint errors.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Don't set high-precision dtype.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Add tests for sanity and CUDA graphs; fix CUDA graphs for sequential.
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Keep make_quantizers API stable

Update num_quantizers instead to pass CUDA graph tests.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Fix import name.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Rename recipe method.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Skip grouped linear sanity test.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Set usage before BF16 gather.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Refactor for nvte_quantize_v2.
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* Format code.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Clean up nvte_quantize_v2.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Test FP32 scales.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Disable CUDA graph.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Simplify LayerNorm linear.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Clean up LayerNorm linear.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* LayerNorm linear bwd gather logic.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Communication updates.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update transformer_engine/pytorch/ops/op.py

Apply MR comment change.
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: kwyss-nvidia <kwyss@nvidia.com>

* Lint fix.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* MR feedback.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Enable CUDA graph tests.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Reduce chance of spurious failure and reword.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Review suggestions from @timmoon10
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update CPP tests.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update common.h
Signed-off-by: Xin Yao <yaox12@outlook.com>

* Update test_float8blockwisetensor.py
Signed-off-by: Xin Yao <yaox12@outlook.com>

---------
Signed-off-by: Keith Wyss <kwyss@nvidia.com>
Signed-off-by: zhongboz <zhongboz@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: kwyss-nvidia <kwyss@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Xin Yao <yaox12@outlook.com>
Co-authored-by: zhongboz <zhongboz@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Xin Yao <yaox12@outlook.com>
parent 0da60449
......@@ -19,6 +19,12 @@ using namespace test;
namespace {
struct QuantizationOptions {
bool force_pow_2_scales = false;
float amax_epsilon = 0.0;
size_t block_scaling_dim = 2u;
};
constexpr size_t kBlockLen = 128;
enum ProcessingMethod {
......@@ -273,7 +279,7 @@ void runTestCase(const ProcessingMethod processing_method, const std::vector<siz
Tensor input("input", shape, itype);
Tensor grad("grad", shape, itype);
Tensor output_c("output_c", shape, otype, rowwise, colwise,
opts.block_scaling_dim == 2 ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D, &opts);
opts.block_scaling_dim == 2 ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D);
Tensor output_dbias("output_dbias", {cols}, itype);
std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(rows * cols);
......@@ -293,10 +299,13 @@ void runTestCase(const ProcessingMethod processing_method, const std::vector<siz
fillCase<EncodingType>(&input, fill_case);
fillUniform(&grad);
QuantizationConfigWrapper quant_config;
quant_config.set_force_pow_2_scales(opts.force_pow_2_scales);
quant_config.set_amax_epsilon(opts.amax_epsilon);
Tensor workspace;
switch (processing_method) {
case ProcessingMethod::CAST_ONLY: {
nvte_quantize(input.data(), output_c.data(), 0);
nvte_quantize_v2(input.data(), output_c.data(), quant_config, nullptr);
break;
}
}
......@@ -345,7 +354,7 @@ void runTestCaseOneDimensionalBlocks(const ProcessingMethod processing_method,
Tensor input("input", shape, itype);
Tensor grad("grad", shape, itype);
Tensor output_c("output_c", shape, otype, rowwise, colwise,
opts.block_scaling_dim == 2 ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D, &opts);
opts.block_scaling_dim == 2 ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D);
Tensor output_dbias("output_dbias", {cols}, itype);
std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(rows * cols);
......@@ -366,9 +375,12 @@ void runTestCaseOneDimensionalBlocks(const ProcessingMethod processing_method,
fillUniform(&grad);
Tensor workspace;
QuantizationConfigWrapper quant_config;
quant_config.set_force_pow_2_scales(opts.force_pow_2_scales);
quant_config.set_amax_epsilon(opts.amax_epsilon);
switch (processing_method) {
case ProcessingMethod::CAST_ONLY: {
nvte_quantize(input.data(), output_c.data(), 0);
nvte_quantize_v2(input.data(), output_c.data(), quant_config, nullptr);
break;
}
}
......@@ -399,9 +411,9 @@ void runTestCaseOneDimensionalBlocks(const ProcessingMethod processing_method,
}
std::vector<std::vector<size_t>> matrix_sizes = {
{1, 16}, {16, 48}, {65, 96}, {128, 128}, {256, 256}, {993, 512},
{256, 65536}, {2048, 6144}, {16384, 128}, {32768, 160}, {4096, 1632}, {1024, 1},
{32, 1024}, {16, 512}, {1024}, {8, 32, 1024}, {16, 8, 4, 512},
{1, 16}, {65, 96}, {256, 256}, {993, 512},
{256, 65536}, {4096, 1632}, {1024, 1},
{16, 512}, {1024}, {8, 32, 1024}, {16, 8, 4, 512},
};
std::vector<InputsFillCase> input_scenarios = {
......@@ -429,6 +441,8 @@ std::vector<ActivationType> Activation_types = {
std::vector<float> amax_epsilons = {
0.0f,
1.0f, // Make large to be observable.
};
} // namespace
......@@ -599,7 +613,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
::testing::ValuesIn(input_scenarios), ::testing::Values(true, false),
::testing::ValuesIn(amax_epsilons), ::testing::Values(true)),
::testing::ValuesIn(amax_epsilons), ::testing::Values(true, false)),
[](const testing::TestParamInfo<FusedCastFloat8BlockwiseTestSuite::ParamType>& info) {
std::string name =
to_string(std::get<0>(info.param)) + "X" + to_string(std::get<1>(info.param));
......@@ -623,7 +637,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
::testing::ValuesIn(input_scenarios), ::testing::Values(true, false),
::testing::ValuesIn(amax_epsilons), ::testing::Values(true)),
::testing::ValuesIn(amax_epsilons), ::testing::Values(true, false)),
[](const testing::TestParamInfo<FusedCastFloat8VectorwiseTestSuite::ParamType>& info) {
std::string name =
to_string(std::get<0>(info.param)) + "X" + to_string(std::get<1>(info.param));
......
......@@ -216,8 +216,7 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
Tensor::Tensor(const std::string& name,
const NVTEShape &shape, const DType type,
const bool rowwise, const bool columnwise,
const NVTEScalingMode &scaling_mode,
const QuantizationOptions* q_opts) {
const NVTEScalingMode &scaling_mode) {
name_ = name;
const size_t seed = create_seed_from_tensor_name(name);
gen_.seed(seed);
......@@ -328,10 +327,6 @@ Tensor::Tensor(const std::string& name,
tensor_.set_columnwise_scale_inv(columnwise_scale_inv, scale_dtype, columnwise_scale_shape);
}
}
if (q_opts != nullptr) {
NVTE_CHECK(q_opts->force_pow_2_scales, "Pow2 scales is required for current implementation.");
NVTE_CHECK(q_opts->amax_epsilon == 0.0, "Amax epsilon must be zero for current implementation.");
}
}
}
......
......@@ -95,29 +95,21 @@ struct TypeInfo{
constexpr static size_t size = sizeof(T);
};
struct QuantizationOptions {
bool force_pow_2_scales = false;
float amax_epsilon = 0.0;
size_t block_scaling_dim = 2u;
};
class Tensor {
public:
Tensor(const std::string& name,
const NVTEShape &shape, const DType type,
const bool rowwise = true,
const bool columnwise = false,
const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING,
const QuantizationOptions* q_opts = nullptr);
const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING);
Tensor(const std::string& name,
const std::vector<size_t> &shape,
const DType type,
const bool rowwise = true,
const bool columnwise = false,
const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING,
const QuantizationOptions* q_opts = nullptr) :
Tensor(name, NVTEShape{shape.data(), shape.size()}, type, rowwise, columnwise, mode, q_opts) {}
const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING) :
Tensor(name, NVTEShape{shape.data(), shape.size()}, type, rowwise, columnwise, mode) {}
Tensor() {}
......
......@@ -19,6 +19,7 @@ from transformer_engine.common.recipe import (
MXFP8BlockScaling,
DelayedScaling,
Float8CurrentScaling,
Float8BlockScaling,
Format,
Recipe,
)
......@@ -49,6 +50,8 @@ def quantization_recipe() -> Recipe:
return MXFP8BlockScaling()
if QUANTIZATION == "fp8_cs":
return Float8CurrentScaling()
if QUANTIZATION == "fp8_block_scaling":
return Float8BlockScaling()
return te.fp8.get_default_fp8_recipe()
......@@ -85,7 +88,7 @@ def main(argv=None, namespace=None):
# Quantization scheme
QUANTIZATION = args.quantization
if QUANTIZATION in ("fp8", "mxfp8"):
if QUANTIZATION in ("fp8", "mxfp8", "fp8_block_scaling"):
global SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE
SEQ_LEN = 32
BATCH_SIZE = 32
......
......@@ -28,6 +28,9 @@ if torch.cuda.device_count() < 2:
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
)
TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS: int = min(4, torch.cuda.device_count())
......@@ -48,7 +51,7 @@ def _run_test(quantization):
all_boolean = [True, False]
@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs"])
@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling"])
def test_distributed(quantization):
if quantization == "fp8" and not fp8_available:
pytest.skip(reason_for_no_fp8)
......@@ -56,4 +59,6 @@ def test_distributed(quantization):
pytest.skip(fp8_available)
if quantization == "mxfp8" and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if quantization == "fp8_block_scaling" and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
_run_test(quantization)
......@@ -27,6 +27,9 @@ from transformer_engine.common import recipe
# Check if FP8 is supported.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
)
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
......@@ -55,6 +58,7 @@ fp8_recipes = [
recipe.DelayedScaling(),
recipe.MXFP8BlockScaling(),
recipe.Float8CurrentScaling(),
recipe.Float8BlockScaling(),
]
# Supported data types
......@@ -316,9 +320,13 @@ def test_make_graphed_callables(
pytest.skip("FP8 needed for FP8 parameters.")
if fp8_weight_caching and not fp8:
pytest.skip("FP8 needed for FP8 parameters.")
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_recipe.float8_block_scaling() and module == "linear_op":
pytest.skip("Module not yet supported for float8_block_scaling with CUDA graphs")
# Run model with different CUDA graph settings.
model_config = model_configs[model_config]
kwargs = dict(
......
......@@ -8,21 +8,18 @@ import transformer_engine as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
Float8BlockQuantizer,
Float8BlockwiseQTensor,
)
from transformer_engine.pytorch.utils import get_device_compute_capability
from references.blockwise_quantizer_reference import CuBLASScaleMunger
from references.blockwise_fp8_gemm_reference import CuBLASRefBlockwiseGemm
def fp8_blockwise_gemm_supported() -> bool:
return (
get_device_compute_capability() >= (9, 0)
and get_device_compute_capability() < (10, 0)
and float(torch.version.cuda) >= 12.9
)
supported, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
return supported
def cublas_gemm_fp8_blockwise_case(
......
......@@ -4,11 +4,14 @@
from typing import Tuple
import math
import os
import pathlib
import pytest
import torch
import transformer_engine as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.utils import get_device_compute_capability
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.common.recipe import Float8BlockScaling
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
Float8BlockQuantizer,
......@@ -18,10 +21,29 @@ from references.blockwise_quantizer_reference import (
BlockwiseQuantizerReference,
QuantizeResult,
)
from test_float8_current_scaling_exact import (
TestFP8RecipeLinearBase,
TestFP8RecipeLayerNormLinearBase,
)
# Read env variable NVTE_TEST_BLOCK_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR to override the default tensor dump directory
TENSOR_DUMP_DIR = pathlib.Path(__file__).resolve().parent.parent.parent / "tensor_dumps"
tensor_dump_dir_env = os.getenv("NVTE_TEST_BLOCK_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR")
if tensor_dump_dir_env is not None:
TENSOR_DUMP_DIR = pathlib.Path(tensor_dump_dir_env)
recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_fp8_block_scaling_available()
class GetRecipes:
@staticmethod
def none():
return None
# TODO replace with call to fp8.py when recipe added.
recipe_available = get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.8
reason_for_no_recipe = "Quantize kernels require TMA and are only relevant with GEMMS."
@staticmethod
def fp8_blockwise():
# return default configs
return Float8BlockScaling()
def initialize_for_many_scales(
......@@ -66,35 +88,7 @@ def initialize_for_many_scales(
return result
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, N",
[
# full tile cases
(128, 128),
(256, 256),
(256, 1024),
(1024, 256),
# Padding required cases
(256, 272),
(303, 300),
(305, 256),
# Some larger tiles.
(2000, 2000),
(2048, 2000),
(2000, 1024),
(2048, 1024),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("eps", [0], ids=["eps_0"])
@pytest.mark.parametrize(
"return_transpose", [True, False], ids=["quantize_transpose", "quantize_only"]
)
@pytest.mark.parametrize("pow_2_scales", [True], ids=["pow2scales"])
@pytest.mark.parametrize("tile_size", [(1, 128), (128, 128)], ids=["1DTile", "2DTile"])
def test_quantization_block_tiling_versus_reference(
def check_quantization_block_tiling_versus_reference(
x_dtype: torch.dtype,
M: int,
N: int,
......@@ -199,12 +193,90 @@ def test_quantization_block_tiling_versus_reference(
[
# full tile cases
(128, 128),
(256, 256),
(256, 1024),
(1024, 256),
# Padding required cases
(256, 272),
(303, 300),
(305, 256),
# Some larger tiles.
(2000, 2000),
(2048, 2000),
(2000, 1024),
(2048, 1024),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("eps", [0], ids=["eps_0"])
@pytest.mark.parametrize(
"return_transpose", [True, False], ids=["quantize_transpose", "quantize_only"]
)
@pytest.mark.parametrize("pow_2_scales", [True], ids=["pow2scales"])
@pytest.mark.parametrize("tile_size", [(1, 128), (128, 128)], ids=["1DTile", "2DTile"])
def test_quantization_block_tiling_versus_reference(
x_dtype: torch.dtype,
M: int,
N: int,
quant_dtype: torch.dtype,
eps: float,
return_transpose: bool,
pow_2_scales: bool,
tile_size: Tuple[int, int],
) -> None:
check_quantization_block_tiling_versus_reference(
x_dtype, M, N, quant_dtype, eps, return_transpose, pow_2_scales, tile_size
)
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, N",
[
# full tile cases
(256, 256),
(2048, 1024),
# Padding required cases
(256, 272),
(303, 300),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("eps", [0], ids=["eps_0"])
@pytest.mark.parametrize(
"return_transpose", [True, False], ids=["quantize_transpose", "quantize_only"]
)
@pytest.mark.parametrize("pow_2_scales", [False], ids=["fp32scales"])
@pytest.mark.parametrize("tile_size", [(1, 128), (128, 128)], ids=["1DTile", "2DTile"])
def test_quantization_block_tiling_versus_reference_fp32_scales(
x_dtype: torch.dtype,
M: int,
N: int,
quant_dtype: torch.dtype,
eps: float,
return_transpose: bool,
pow_2_scales: bool,
tile_size: Tuple[int, int],
) -> None:
check_quantization_block_tiling_versus_reference(
x_dtype, M, N, quant_dtype, eps, return_transpose, pow_2_scales, tile_size
)
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, N",
[
# full tile cases
(128, 128),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("eps", [0], ids=["eps_0"])
@pytest.mark.parametrize("pow_2_scales", [True, False], ids=["pow2scales", "fp32scales"])
@pytest.mark.parametrize("tile_size", [(128, 128)])
@pytest.mark.parametrize("extrema_high", [False, True], ids=["zeros", "maxes"])
def test_quantization_block_tiling_extrema_versus_reference(
......@@ -292,3 +364,130 @@ def test_quantization_block_tiling_extrema_versus_reference(
atol=0.0,
rtol=0.0,
)
# FP8 blockwise scaling
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
class TestFP8BlockScalingRecipeLinear(TestFP8RecipeLinearBase):
@staticmethod
def setup_class(cls) -> None:
# Configure RNG
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@pytest.mark.parametrize(
"batch_size, hidden_size, out_size",
[
(16, 256, 128),
],
)
@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"])
@pytest.mark.parametrize(
"recipe1, recipe2",
[
(GetRecipes.none, GetRecipes.fp8_blockwise),
],
)
def test_fp8_current_scaling_with_linear_module(
self,
recipe1,
recipe2,
batch_size,
hidden_size,
out_size,
dtype,
use_bias=True,
):
fp8_zero_tolerance_tensor_dumps_recipe2 = None
# Check the tensor dump dir; if it exists, read files to get y, dgrad, wgrad, and bgrad.
# If all four tensors cannot be loaded, leave the tensor dump set to None.
tensor_map = self._check_golden_tensor_dumps(
TENSOR_DUMP_DIR, recipe2, (batch_size, hidden_size, out_size), dtype, use_bias
)
if tensor_map is not None:
fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map
self.compare_recipe(
recipe1,
recipe2,
batch_size,
hidden_size,
out_size,
use_bias,
seed=torch.initial_seed(),
dtype=dtype,
y_error=0.5,
dgrad_error=1,
wgrad_error=1,
bgrad_error=0.5,
recipe1_golden_tensors=None,
recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2,
)
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
class TestFP8BlockScalingRecipeLayerNormLinear(TestFP8RecipeLayerNormLinearBase):
@staticmethod
def setup_class(cls) -> None:
# Configure RNG
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@pytest.mark.parametrize(
"batch_size, hidden_size, out_size",
[
(16, 256, 128),
],
)
@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"])
@pytest.mark.parametrize(
"recipe1, recipe2",
[
(GetRecipes.none, GetRecipes.fp8_blockwise),
],
)
def test_fp8_current_scaling_with_layernorm_linear_module(
self,
recipe1,
recipe2,
batch_size,
hidden_size,
out_size,
dtype,
use_bias=True,
):
fp8_zero_tolerance_tensor_dumps_recipe2 = None
# Check the tensor dump dir; if it exists, read files to get y, dgrad, wgrad, and bgrad.
# If all four tensors cannot be loaded, leave the tensor dump set to None.
tensor_map = self._check_golden_tensor_dumps(
TENSOR_DUMP_DIR,
recipe2,
(batch_size, hidden_size, out_size),
dtype,
use_bias,
"LayerNorm",
)
if tensor_map is not None:
fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map
self.compare_recipe(
recipe1,
recipe2,
batch_size,
hidden_size,
out_size,
use_bias,
seed=torch.initial_seed(),
dtype=dtype,
y_error=0.5,
ln_out_error=0.5,
dgrad_error=1.6,
wgrad_error=1,
bgrad_error=0.5,
recipe1_golden_tensors=None,
recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2,
)
......@@ -82,7 +82,8 @@ class TestFP8RecipeLinearBase:
@staticmethod
def _get_mean_abs_relative_error(a, b):
return torch.mean(torch.abs((a - b) / b))
error = torch.where(b == 0, torch.ne(a, b), torch.abs((a - b) / b))
return torch.mean(error)
@staticmethod
def _load_golden_tensor_values(a, b):
......@@ -97,9 +98,12 @@ class TestFP8RecipeLinearBase:
fp8_type_g = get_fp8_torch_dtype(recipe, fprop_tensor=False)
# Expected tensor names based on the naming template
scaling_type = ( # Assuming the scaling type is PER_TENSOR for this example
"ScalingType.PER_TENSOR"
)
if recipe.float8_current_scaling():
scaling_type = "ScalingType.PER_TENSOR"
elif recipe.float8_block_scaling():
scaling_type = "ScalingType.VECTOR_TILED_X_AND_G_BLOCK_TILED_W"
else:
scaling_type = "Unknown"
current_seed = torch.initial_seed() # Get the current seed
expected_tensor_names = {
......@@ -437,9 +441,13 @@ class TestFP8RecipeLayerNormLinearBase(TestFP8RecipeLinearBase):
fp8_type_g = get_fp8_torch_dtype(recipe, fprop_tensor=False)
# Expected tensor names based on the naming template
scaling_type = ( # Assuming the scaling type is PER_TENSOR for this example
"ScalingType.PER_TENSOR"
)
if recipe.float8_current_scaling():
scaling_type = "ScalingType.PER_TENSOR"
elif recipe.float8_block_scaling():
scaling_type = "ScalingType.VECTOR_TILED_X_AND_G_BLOCK_TILED_W"
else:
scaling_type = "Unknown"
current_seed = torch.initial_seed() # Get the current seed
expected_tensor_names = {
......
......@@ -110,7 +110,10 @@ class TestFloat8BlockwiseTensor:
dims = _to_list(dims)
# Initialize random data
# Note: Make sure values are not all close to zero, or else
# test may pass trivially.
x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
x_ref.view(-1)[0] = 0.75
x_ref_cuda = x_ref.to("cuda")
# Cast to FP8 and back
......@@ -150,6 +153,24 @@ class TestFloat8BlockwiseTensor:
)
self._test_quantize_dequantize(quantizer=quantizer, dtype=dtype, atol=atol, rtol=rtol)
@pytest.mark.parametrize("fp8_dtype", _fp8_dtypes)
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("block_scaling_dim", [1])
def test_quantize_dequantize_columnwise_only(
self, fp8_dtype: tex.DType, dtype: torch.dtype, block_scaling_dim: int
) -> None:
atol = _tols[fp8_dtype]["atol"]
rtol = _tols[fp8_dtype]["rtol"]
quantizer = Float8BlockQuantizer(
fp8_dtype=fp8_dtype,
rowwise=False,
columnwise=True,
block_scaling_dim=block_scaling_dim,
)
self._test_quantize_dequantize(
quantizer=quantizer, dtype=dtype, atol=atol, rtol=rtol, use_cpp_allocation=True
)
@pytest.mark.parametrize(
"dims", [[], 256, 311, [264], [256, 512], [250, 500], [7, 5, 3], [2, 3, 5, 3]]
)
......
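For readers who have not driven the quantizer directly, here is a rough sketch of what the new columnwise-only path above exercises. It assumes the usual TE quantizer interface; in particular, the quantizer being callable and the quantized tensor exposing dequantize() are assumptions on my part and are not spelled out in this diff:

import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer

# Columnwise-only 1D (1x128) blockwise quantization, mirroring the new test above.
quantizer = Float8BlockQuantizer(
    fp8_dtype=tex.DType.kFloat8E4M3,
    rowwise=False,
    columnwise=True,
    block_scaling_dim=1,
)
x = torch.randn(256, 512, dtype=torch.bfloat16, device="cuda")
x_fp8 = quantizer(x)       # assumed callable API; yields a Float8BlockwiseQTensor with columnwise data/scales
x_rt = x_fp8.dequantize()  # assumed API; approximate round trip back to high precision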
......@@ -4,7 +4,7 @@
from collections.abc import Iterable
import io
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, List, Tuple, Union, Optional
import pytest
import torch
......@@ -158,6 +158,32 @@ class TestFloat8Tensor:
def test_quantize_dequantize_dims(self, dims: DimsType) -> None:
self._test_quantize_dequantize(dims=dims)
@pytest.mark.parametrize("fp8_dtype", _fp8_dtypes)
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("noop", [True, False])
def test_quantize_dequantize_noop(
self, fp8_dtype: tex.DType, dtype: torch.dtype, noop: bool
) -> None:
noop_tensor = torch.zeros(1, dtype=torch.float32, device="cuda")
if noop:
noop_tensor = torch.ones(1, dtype=torch.float32, device="cuda")
dims = 23
scale: float = 3.5
# Initialize random data
x_ref = 2 * torch.rand(_to_list(dims), dtype=dtype, device="cpu") - 1
# Cast to FP8 and back
x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale)
# If noop is set, quantizing a different input should leave the output unchanged (still x_fp8_orig).
x_ref_noop_test = 2 * x_ref.cuda()
x_fp8_orig = x_fp8.clone()
x_fp8.quantize_(x_ref_noop_test, noop_flag=noop_tensor)
if noop_tensor.item() == 1.0:
torch.testing.assert_close(x_fp8, x_fp8_orig, atol=0, rtol=0)
else:
torch.testing.assert_close(x_fp8, x_ref_noop_test, **_tols[fp8_dtype])
def test_basic_ops(
self,
dims: DimsType = 23,
......
......@@ -50,6 +50,9 @@ import transformer_engine_torch as tex
# Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
)
sm_80plus = get_device_compute_capability() >= (8, 0)
......@@ -104,6 +107,7 @@ fp8_recipes = [
recipe.MXFP8BlockScaling(),
recipe.DelayedScaling(),
recipe.Float8CurrentScaling(),
recipe.Float8BlockScaling(),
]
......@@ -563,6 +567,8 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_m
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
config = model_configs[model]
......@@ -675,6 +681,8 @@ def test_gpt_full_activation_recompute(
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
config = model_configs[model]
......@@ -1528,6 +1536,8 @@ def test_grouped_linear_accuracy(
pytest.skip("MXFP8 unsupported for grouped linear.")
if fp8 and recipe.float8_current_scaling():
pytest.skip("Float8 Current Scaling unsupported for grouped linear.")
if recipe.float8_block_scaling():
pytest.skip("Grouped linear for FP8 blockwise unsupported.")
config = model_configs[model]
if config.seq_len % 16 != 0 and fp8:
......@@ -1723,6 +1733,8 @@ def test_padding_grouped_linear_accuracy(
pytest.skip("MXFP8 unsupported for grouped linear.")
if fp8 and recipe.float8_current_scaling():
pytest.skip("Float8 Current Scaling unsupported for grouped linear.")
if recipe.float8_block_scaling():
pytest.skip("Float8 block scaling unsupported for grouped linear.")
config = model_configs[model]
if config.seq_len % 16 != 0 and fp8:
......@@ -1933,6 +1945,8 @@ def test_gpt_fp8_parameters(dtype, bs, model, recipe):
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
config = model_configs[model]
......
......@@ -46,6 +46,9 @@ from test_numerics import reset_rng_states, dtype_tols
# Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
)
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
......@@ -106,6 +109,7 @@ fp8_recipes = [
None, # Test non-FP8
recipe.MXFP8BlockScaling(), # Test default
recipe.Float8CurrentScaling(), # Test default
recipe.Float8BlockScaling(), # Test default
recipe.DelayedScaling(), # Test default
recipe.DelayedScaling( # Test most_recent algo
amax_history_len=16,
......@@ -439,6 +443,8 @@ def test_sanity_layernorm_linear(
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
......@@ -470,6 +476,8 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad):
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
......@@ -502,6 +510,8 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
......@@ -543,10 +553,14 @@ def test_sanity_grouped_linear(
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8():
pytest.skip("Grouped linear does not support MXFP8")
if fp8_recipe.float8_current_scaling():
pytest.skip("Grouped linear does not support FP8 current scaling")
if fp8_recipe.float8_block_scaling():
pytest.skip("Grouped linear does not support FP8 block scaling")
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
......@@ -590,6 +604,8 @@ def test_sanity_layernorm_mlp(
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
......@@ -640,6 +656,8 @@ def test_sanity_gpt(
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
......@@ -707,6 +725,8 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
......@@ -766,6 +786,8 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, no
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
......@@ -823,6 +845,8 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
......@@ -858,6 +882,8 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
......@@ -896,6 +922,8 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
......@@ -937,6 +965,8 @@ def test_sanity_gradient_accumulation_fusion(
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if not config.is_fp8_supported():
......@@ -979,8 +1009,12 @@ def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamm
if fp8_recipe is not None:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if fp8_recipe.float8_block_scaling():
pytest.skip("cuda graph not supported for float8_block_scaling recipe")
if not config.is_fp8_supported():
pytest.skip("Model config does not support FP8")
......
......@@ -32,8 +32,8 @@ void act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
constexpr NVTETensor workspace = nullptr;
constexpr const NVTETensor grad = nullptr;
quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, OP>(input, grad, nullptr, output, dbias,
workspace, stream);
quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, OP>(input, grad, output, dbias, workspace,
nullptr, stream);
}
template <typename ComputeType, typename Param, ComputeType (*OP)(ComputeType, const Param &)>
......@@ -46,8 +46,8 @@ void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output,
constexpr NVTETensor dbias = nullptr;
constexpr NVTETensor workspace = nullptr;
quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, OP>(input, grad, nullptr, output, dbias,
workspace, stream);
quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, OP>(input, grad, output, dbias, workspace,
nullptr, stream);
}
template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType, const Param &)>
......
......@@ -233,10 +233,12 @@ struct Tensor {
struct QuantizationConfig {
bool force_pow_2_scales = false;
float amax_epsilon = 0.0f;
NVTETensor noop_tensor = nullptr;
static constexpr size_t attr_sizes[] = {
sizeof(bool), // force_pow_2_scales
sizeof(float) // amax_epsilon
sizeof(float), // amax_epsilon
sizeof(NVTETensor) // noop_tensor
};
};
......
......@@ -89,7 +89,7 @@ extern "C" {
*/
void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Casts input tensor to FP8/MXFP8, providing the option to immediately exit the kernel
/*! \brief Casts input tensor to FP8/MXFP8/BlockwiseFP8, providing the option to immediately exit the kernel
* based on the value of the 'noop' tensor.
* The type of quantized tensor in the output depends on the scaling mode of the output
* tensor. See file level comments.
......@@ -102,6 +102,16 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea
void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop,
cudaStream_t stream);
/*! \brief Casts input tensor to quantized output tensor, with advanced quantization options.
*
* \param[in] input Input tensor to be cast.
* \param[in,out] output Output quantized tensor.
* \param[in] quant_config Quantization configuration.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_quantize_v2(const NVTETensor input, NVTETensor output,
const NVTEQuantizationConfig quant_config, cudaStream_t stream);
/*! \brief Casts input tensor to MXFP8. Additionally, reduces the input along columns.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
......
......@@ -286,6 +286,12 @@ enum NVTEQuantizationConfigAttribute {
kNVTEQuantizationConfigForcePow2Scales = 0,
/*! Small value to add to amax for numerical stability */
kNVTEQuantizationConfigAmaxEpsilon = 1,
/*! Noop tensor (containing a single scalar).
If the scalar's value is 1, the quantization kernel exits early.
This is a tensor because the flag must live on the GPU in order to enable
conditional early exit even when captured in a static CUDA graph.
*/
kNVTEQuantizationConfigNoopTensor = 2,
kNVTEQuantizationConfigNumAttributes
};
......@@ -724,6 +730,12 @@ class QuantizationConfigWrapper {
&amax_epsilon, sizeof(float));
}
/*! \brief Set noop tensor pointer */
void set_noop_tensor(NVTETensor noop_tensor) {
nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigNoopTensor, &noop_tensor,
sizeof(NVTETensor));
}
private:
/*! \brief Wrapped NVTEQuantizationConfig. */
NVTEQuantizationConfig config_ = nullptr;
......
......@@ -5,6 +5,7 @@
"""This module provides predefined FP8 recipes."""
from __future__ import annotations
import warnings
import os
from enum import Enum
from typing import Literal, Optional, Union, Callable, NamedTuple
from pydantic.dataclasses import dataclass
......@@ -81,6 +82,10 @@ class Recipe:
"""Whether the given recipe is per-tensor scaling."""
return isinstance(self, (DelayedScaling, Float8CurrentScaling))
def float8_block_scaling(self):
"""Whether the given recipe is float8 blockwise scaling."""
return isinstance(self, Float8BlockScaling)
@dataclass()
class DelayedScaling(Recipe):
......@@ -287,3 +292,99 @@ class MXFP8BlockScaling(Recipe):
def __repr__(self) -> str:
return f"margin={self.margin}, format={str(self.fp8_format).split('.')[1]},"
@dataclass()
class Float8BlockScaling(Recipe):
"""
Use block-wise scaling for FP8 tensors.
In this strategy, tensors are scaled in blockwise fashion. Values within
each block share a common scaling factor. The block dimensionality
can be configured. The scaling factors are float32 containers. They
will by default be constrained to powers of 2.
Since the scaling happens in a particular direction (either rowwise
or columnwise), the quantized tensor and its transpose are not numerically
equivalent. For this reason, when Transformer Engine needs both the FP8 tensor
and its transpose (e.g. for the forward and backward passes), both versions
are computed directly from the high-precision input during quantization to
avoid double-quantization error.
NOTE: To relax the default constraint that scales be powers of 2, set env variable
NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1 to override it for the recipe defaults.
export NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1
Or initialize the Recipe with non-default QParams in code for increased control.
Parameters
----------
fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3
Controls the FP8 data format used during forward and backward
pass.
fp8_quant_fwd_inp: QParams, default QParams{power_2_scale=True, amax_epsilon=0.0}
used for quantization of input tensor x
fp8_quant_fwd_weight: QParams, default QParams{power_2_scale=True, amax_epsilon=0.0}
used for quantization of weight tensor w
fp8_quant_bwd_grad: QParams, default QParams{power_2_scale=True, amax_epsilon=0.0}
used for quantization of gradient tensor dY
x_block_scaling_dim: Choice to use 1x128 (1 dimensional) or 128x128 (2 dimensional)
qblock scaling for x.
w_block_scaling_dim: Choice to use 1x128 (1 dimensional) or 128x128 (2 dimensional)
qblock scaling for w.
grad_block_scaling_dim: Choice to use 1x128 (1 dimensional) or 128x128 (2 dimensional)
qblock scaling for grad.
fp8_gemm_fprop: MMParams, default MMParams.use_split_accumulator=True
used for calculating output y in forward pass
fp8_gemm_dgrad: MMParams, default MMParams.use_split_accumulator=True
used for calculating dgrad in backward pass
fp8_gemm_wgrad: MMParams, default MMParams.use_split_accumulator=True
used for calculating wgrad in backward pass
"""
use_f32_scales: bool = os.getenv("NVTE_FP8_BLOCK_SCALING_FP32_SCALES", "0") == "1"
fp8_format: Format = Format.E4M3
fp8_quant_fwd_inp = QParams(power_2_scale=not use_f32_scales, amax_epsilon=0.0)
fp8_quant_fwd_weight = QParams(power_2_scale=not use_f32_scales, amax_epsilon=0.0)
fp8_quant_bwd_grad = QParams(power_2_scale=not use_f32_scales, amax_epsilon=0.0)
x_block_scaling_dim: int = 1
w_block_scaling_dim: int = 2
grad_block_scaling_dim: int = 1
fp8_gemm_fprop: MMParams = MMParams(use_split_accumulator=True)
fp8_gemm_dgrad: MMParams = MMParams(use_split_accumulator=True)
fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True)
fp8_dpa: bool = False
fp8_mha: bool = False
def __post_init__(self) -> None:
assert self.x_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for x"
assert self.w_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for w"
assert self.grad_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for grad"
assert not (
self.x_block_scaling_dim == 2 and self.w_block_scaling_dim == 2
), "2D by 2D block gemm not supported."
assert not (
self.x_block_scaling_dim == 2 and self.grad_block_scaling_dim == 2
), "2D by 2D block gemm not supported."
assert not (
self.w_block_scaling_dim == 2 and self.grad_block_scaling_dim == 2
), "2D by 2D block gemm not supported."
assert self.fp8_gemm_fprop.use_split_accumulator, "Split accumulator required for fprop."
assert self.fp8_gemm_dgrad.use_split_accumulator, "Split accumulator required for dgrad."
assert self.fp8_gemm_wgrad.use_split_accumulator, "Split accumulator required for wgrad."
def __repr__(self) -> str:
return (
f"format={str(self.fp8_format).split('.')[1]}, "
f"fp8_quant_fwd_inp={self.fp8_quant_fwd_inp}, "
f"fp8_quant_fwd_weight={self.fp8_quant_fwd_weight}, "
f"fp8_quant_bwd_grad={self.fp8_quant_bwd_grad}, "
f"x_block_scaling_dim={self.x_block_scaling_dim}, "
f"w_block_scaling_dim={self.w_block_scaling_dim}, "
f"grad_block_scaling_dim={self.grad_block_scaling_dim}, "
f"fp8_gemm_fprop={self.fp8_gemm_fprop}, "
f"fp8_gemm_dgrad={self.fp8_gemm_dgrad}, "
f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, "
f"fp8_dpa={self.fp8_dpa}, "
f"fp8_mha={self.fp8_mha}"
)
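As a reference point, here is a minimal usage sketch of the new recipe from PyTorch, modeled on how the existing recipes (e.g. Float8CurrentScaling) are enabled; the layer and batch sizes are illustrative, and the recipe is assumed to require a device/CUDA combination for which FP8GlobalStateManager.is_fp8_block_scaling_available() returns True (the commits above gate the feature on CUDA 12.9):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8BlockScaling

# Defaults: 1x128 (1D) scaling for activations and gradients, 128x128 (2D) for weights,
# power-of-2 scales (set NVTE_FP8_BLOCK_SCALING_FP32_SCALES=1 before import to relax this).
recipe = Float8BlockScaling()

layer = te.Linear(1024, 1024, bias=True, params_dtype=torch.bfloat16).cuda()
x = torch.randn(32, 1024, dtype=torch.bfloat16, device="cuda", requires_grad=True)

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y = layer(x)
y.sum().backward()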
......@@ -429,6 +429,9 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config,
case kNVTEQuantizationConfigAmaxEpsilon:
std::memcpy(buf, &config_.amax_epsilon, attr_size);
break;
case kNVTEQuantizationConfigNoopTensor:
std::memcpy(buf, &config_.noop_tensor, attr_size);
break;
default:
NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
}
......@@ -458,6 +461,9 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
case kNVTEQuantizationConfigAmaxEpsilon:
std::memcpy(&config_.amax_epsilon, buf, attr_size);
break;
case kNVTEQuantizationConfigNoopTensor:
std::memcpy(&config_.noop_tensor, buf, attr_size);
break;
default:
NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
}
......
......@@ -29,11 +29,35 @@ void quantize_transpose_square_blockwise(const SimpleTensor &input, SimpleTensor
const bool return_transpose, const bool pow_2_scale,
cudaStream_t stream);
// enum class for rowwise usage
enum class FP8BlockwiseRowwiseOption {
// No rowwise data
NONE,
// Rowwise data, scales in GEMM format
ROWWISE
// TODO: FP8 all gather requires some changes.
// 1. Compact scales are better for gathering than the GEMM format.
};
// enum class for columnwise usage
// On Hopper (sm90), where only the TN FP8 GEMM is available, a columnwise transpose is needed when doing 1D block scaling
enum class FP8BlockwiseColumnwiseOption {
// No columnwise data
NONE,
// Columnwise data transposed from original shape.
// Scales in GEMM format corresponding to GEMM ingesting transposed column data.
COLUMNWISE_TRANSPOSE
// TODO: FP8 all gather requires some changes.
// 1. The transpose gets in the way of the all gather.
// 2. Compact scales are better for gathering than the GEMM format.
};
void quantize_transpose_vector_blockwise(const SimpleTensor &input, SimpleTensor &scale_inv,
SimpleTensor &scale_inv_t, SimpleTensor &output,
SimpleTensor &output_t, const float epsilon,
const bool return_transpose, const bool pow_2_scale,
cudaStream_t stream);
FP8BlockwiseRowwiseOption rowwise_option,
FP8BlockwiseColumnwiseOption columnwise_option,
const bool pow_2_scale, cudaStream_t stream);
} // namespace transformer_engine::detail
......