Unverified commit 9985b02c, authored by Zhongbo Zhu and committed by GitHub

[PyTorch] FP8 Subchannel Recipe With FP8 Gather And Configurable Scaling Factor Tensor Swizzling (#1707)

* functional kernel for columnwise + no-transpose option, still hacky
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* pass all quantizer unit tests
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactor, add gemm ready api
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* make format options private members, simplify api
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* swizzle scales right before gemm
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* bug fix of single layer test
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* attempt to fix lint issue
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fp8 gather pass, need minor refine
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix return_layernorm_output_gathered case
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* remove special cases, add sanity check before gemm
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* fix lint
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* lint ungrouped imports
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Implement dequantize for compact 1D blocks.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* add more unit test with dequantize compact supported
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* lint again
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* make ag for subchannel respect async
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* zero tolerance in distributed test
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* fix zero tolerance test
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* resolve rebase issues
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* lint & format
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* fix lint
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* clean up
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* bug fix
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* relax rtol for fp32 distributed test
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* fix some ci issue
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* fix ci test failure in debug mode
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* Force row-wise and column-wise data to have same data format

Prototype "all-gather usage" in quantizer.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove dead logic for high-precision AGs
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Debug FP8 block-wise tensor tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug distributed test
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Handle case where LayerNormLinear returns gathered norm output
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* fix debug mode
Signed-off-by: zhongboz <zhongboz@nvidia.com>

---------

Signed-off-by: zhongboz <zhongboz@nvidia.com>
Signed-off-by: Keith Wyss <kwyss@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Keith Wyss <kwyss@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 6123d7e0
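
Note on the scheme before the diff: this recipe quantizes along 1x128 subchannels, one scaling factor per 128 contiguous elements of a row. The sketch below is a minimal reference for the rowwise half of that scheme in the new COMPACT layout, under stated assumptions (shapes divisible by the tile length, e4m3 output); it is illustrative only and omits the pow-2 scale forcing, amax epsilon, padding, and the columnwise/transposed outputs that the kernels in this diff handle.

    import torch

    def quantize_1x128(x: torch.Tensor, tile: int = 128):
        """One decode scale (scale-inv) per 1x128 rowwise tile, compact layout."""
        M, N = x.shape
        assert N % tile == 0, "sketch assumes N divisible by the tile length"
        fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448 for e4m3
        tiles = x.view(M, N // tile, tile).float()
        amax = tiles.abs().amax(dim=-1)            # (M, N // tile) per-tile amaxes
        scale = fp8_max / amax.clamp(min=1e-12)    # encode scale per tile
        q = (tiles * scale.unsqueeze(-1)).clamp(-fp8_max, fp8_max)
        return q.view(M, N).to(torch.float8_e4m3fn), 1.0 / scale

    q, s_inv = quantize_1x128(torch.randn(256, 1024))
    print(q.shape, s_inv.shape)  # torch.Size([256, 1024]) torch.Size([256, 8])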
@@ -100,11 +100,15 @@ def main(argv=None, namespace=None):
 
     # Quantization scheme
     QUANTIZATION = args.quantization
-    if QUANTIZATION in ("fp8", "mxfp8", "fp8_block_scaling"):
-        global SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE
+    global SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE
+    if QUANTIZATION in ("fp8", "mxfp8"):
         SEQ_LEN = 32
         BATCH_SIZE = 32
         HIDDEN_SIZE = 128
+    elif QUANTIZATION == "fp8_block_scaling":
+        SEQ_LEN = 128
+        BATCH_SIZE = 128
+        HIDDEN_SIZE = 512
 
     test_dict = [
         test_quantizer,
@@ -185,7 +189,7 @@ def _get_tolerances(dtype):
     if dtype == torch.bfloat16:
         return {"rtol": 1.6e-2, "atol": 1e-5}
     if dtype == torch.float32:
-        return {"rtol": 1e-4, "atol": 1e-4}
+        return {"rtol": 1.2e-4, "atol": 1e-4}
     raise ValueError(f"Unsupported dtype ({dtype})")
@@ -649,7 +653,7 @@ def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs
     if "return_layernorm_output" in kwargs:
         output_single_node, norm_s = output_single_node
         output_distributed, norm_d = output_distributed
-        if sequence_parallel:
+        if sequence_parallel and not kwargs.get("return_layernorm_output_gathered", False):
             norm_d = _gather(norm_d)
         _check_outputs(norm_s, norm_d)
@@ -758,7 +762,7 @@ def _test_layernorm_mlp(set_parallel_mode=None, sequence_parallel=False, **kwarg
     if "return_layernorm_output" in kwargs:
         output_single_node, norm_s = output_single_node
         output_distributed, norm_d = output_distributed
-        if sequence_parallel:
+        if sequence_parallel and not kwargs.get("return_layernorm_output_gathered", False):
             norm_d = _gather(norm_d)
         _check_outputs(norm_s, norm_d)
......
@@ -260,6 +260,7 @@ class BlockwiseQuantizerReference:
         eps: float = 0.0,
         pow_2_scales: bool = False,
         quant_tile_shape: Tuple[int, int] = (128, 128),
+        munge_scale_shapes: bool = True,
     ) -> QuantizeResult:
         # sanity checks
         assert x.dim() == 2
@@ -277,27 +278,33 @@ class BlockwiseQuantizerReference:
         assert quant_tile_shape in ((1, 128), (128, 128))
         if quant_tile_shape[0] == 1:
             # Quantize row-wise
-            return self.scale_munger.munge_scale_shapes_for_backend(
-                self._quantize_vector_tiling(
-                    x,
-                    quant_dtype,
-                    tile_len=quant_tile_shape[1],
-                    return_transpose=return_transpose,
-                    pow_2_scales=pow_2_scales,
-                    eps=eps,
-                ),
-                quant_tile_shape,
+            result = self._quantize_vector_tiling(
+                x,
+                quant_dtype,
+                tile_len=quant_tile_shape[1],
+                return_transpose=return_transpose,
+                pow_2_scales=pow_2_scales,
+                eps=eps,
             )
+            if munge_scale_shapes:
+                result = self.scale_munger.munge_scale_shapes_for_backend(
+                    result,
+                    quant_tile_shape,
+                )
+            return result
         else:
             # Quantize block-wise
-            return self.scale_munger.munge_scale_shapes_for_backend(
-                self._quantize_square_block_tiling(
-                    x,
-                    quant_dtype,
-                    tile_len=quant_tile_shape[0],
-                    return_transpose=return_transpose,
-                    pow_2_scales=pow_2_scales,
-                    eps=eps,
-                ),
-                quant_tile_shape,
+            result = self._quantize_square_block_tiling(
+                x,
+                quant_dtype,
+                tile_len=quant_tile_shape[0],
+                return_transpose=return_transpose,
+                pow_2_scales=pow_2_scales,
+                eps=eps,
             )
+            if munge_scale_shapes:
+                result = self.scale_munger.munge_scale_shapes_for_backend(
+                    result,
+                    quant_tile_shape,
+                )
+            return result
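
For context on what `munge_scale_shapes_for_backend` does, and why the new `munge_scale_shapes` flag lets tests skip it: the GEMM-ready layout stores 1x128 rowwise scales transposed, with the inner dimension padded to a multiple of 4, while the compact layout keeps the natural (M, ceil(N/128)) shape. A hypothetical sketch of that munge step, inferred from the shape logic elsewhere in this commit (the actual ScaleMunger implementation may differ in details):

    import torch

    def munge_scales_for_gemm(scale_inv: torch.Tensor) -> torch.Tensor:
        # (M, ceil(N/128)) -> (ceil(N/128), M), with M padded up to a multiple of 4
        s_t = scale_inv.transpose(-2, -1).contiguous()
        m = s_t.shape[-1]
        padded_m = (m + 3) // 4 * 4
        if padded_m != m:
            s_t = torch.nn.functional.pad(s_t, (0, padded_m - m))
        return s_t

    compact = torch.rand(250, 8)  # scales of a 250x1024 input, 1x128 tiles
    print(munge_scales_for_gemm(compact).shape)  # torch.Size([8, 252])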
@@ -88,6 +88,126 @@ def initialize_for_many_scales(
     return result
 
 
+@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
+@pytest.mark.parametrize(
+    "M, N",
+    [
+        # full tile cases
+        (128, 128),
+        (256, 256),
+        (256, 1024),
+        (1024, 256),
+        # Padding required cases
+        (256, 272),
+        (303, 300),
+        (305, 256),
+        # Some larger tiles.
+        (2000, 2000),
+        (2048, 2000),
+        (2000, 1024),
+        (2048, 1024),
+    ],
+)
+@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
+@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
+@pytest.mark.parametrize("eps", [0], ids=["eps_0"])
+@pytest.mark.parametrize("pow_2_scales", [True], ids=["pow2scales"])
+def test_quantization_1D_block_tiling_with_compact_data_and_scales(
+    x_dtype: torch.dtype,
+    M: int,
+    N: int,
+    quant_dtype: torch.dtype,
+    eps: float,
+    pow_2_scales: bool,
+) -> None:
+    te_dtype = TE_DType[quant_dtype]
+    tile_size = (1, 128)
+    # This test runs a comparison of the ref class versus the class using
+    # CUDA kernels to quantize. They should quantize identically for pixels
+    # that are not DC values in the scale factor shape.
+    ref_quantizer = BlockwiseQuantizerReference()
+    sut_quantizer = Float8BlockQuantizer(
+        fp8_dtype=te_dtype,
+        rowwise=True,
+        columnwise=True,
+        amax_epsilon=eps,
+        force_pow_2_scales=pow_2_scales,
+        block_scaling_dim=1,
+        all_gather_usage=True,
+    )
+
+    # Setup device and random seed
+    device = "cuda"
+    seed = 0
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
+    # Input
+    x = initialize_for_many_scales((M, N), tile_size, dtype=x_dtype, device=device)
+
+    x_fp8_sut = sut_quantizer.make_empty((M, N), dtype=x_dtype, device=device, requires_grad=False)
+    x_fp8_sut = sut_quantizer.update_quantized(x, x_fp8_sut)
+    x_fp8_sut_cpp_alloc = sut_quantizer(x)
+
+    assert x_fp8_sut._rowwise_data is not None
+    qx: torch.Tensor = x_fp8_sut._rowwise_data.view(dtype=quant_dtype)
+    assert x_fp8_sut._rowwise_scale_inv is not None
+    sx: torch.Tensor = x_fp8_sut._rowwise_scale_inv
+    qx_t = x_fp8_sut._columnwise_data
+    sx_t = x_fp8_sut._columnwise_scale_inv
+
+    qresult_ref = ref_quantizer.quantize(
+        x,
+        quant_dtype=quant_dtype,
+        return_transpose=True,
+        eps=eps,
+        pow_2_scales=pow_2_scales,
+        quant_tile_shape=tile_size,
+        munge_scale_shapes=False,
+    )
+    qx_ref, sx_ref, qx_t_ref, sx_t_ref = (
+        qresult_ref.data,
+        qresult_ref.scale,
+        qresult_ref.data_t,
+        qresult_ref.scale_t,
+    )
+    # match the reference quantize transpose output with the columnwise non-transpose method
+    qx_t_ref = qx_t_ref.transpose(-1, -2).contiguous()
+    sx_t_ref = sx_t_ref.transpose(-1, -2).contiguous()
+
+    # Check
+    torch.testing.assert_close(qx.float(), qx_ref.float(), atol=0.0, rtol=0.0)
+    torch.testing.assert_close(sx, sx_ref, atol=0.0, rtol=0.0)
+    assert qx_t is not None
+    qx_t = qx_t.view(dtype=quant_dtype)
+    assert qx_t_ref is not None
+    assert sx_t is not None
+    assert sx_t_ref is not None
+    torch.testing.assert_close(qx_t.float(), qx_t_ref.float(), atol=0.0, rtol=0.0)
+    torch.testing.assert_close(sx_t, sx_t_ref, atol=0.0, rtol=0.0)
+
+    # check that the C++ and Python allocators are equivalent
+    torch.testing.assert_close(
+        x_fp8_sut._rowwise_data, x_fp8_sut_cpp_alloc._rowwise_data, atol=0.0, rtol=0.0
+    )
+    torch.testing.assert_close(
+        x_fp8_sut._rowwise_scale_inv, x_fp8_sut_cpp_alloc._rowwise_scale_inv, atol=0.0, rtol=0.0
+    )
+    torch.testing.assert_close(
+        x_fp8_sut._columnwise_data, x_fp8_sut_cpp_alloc._columnwise_data, atol=0.0, rtol=0.0
+    )
+    torch.testing.assert_close(
+        x_fp8_sut._columnwise_scale_inv,
+        x_fp8_sut_cpp_alloc._columnwise_scale_inv,
+        atol=0.0,
+        rtol=0.0,
+    )
+    # check if the fp8 output between C++ and Python are the same
+    assert x_fp8_sut._data_format == x_fp8_sut_cpp_alloc._data_format
+
+
 def check_quantization_block_tiling_versus_reference(
     x_dtype: torch.dtype,
     M: int,
......
@@ -385,7 +385,7 @@ class TestFP8RecipeLinearBase:
         )
 
         # recipe1
-        using_fp8_recipe = recipe1 != GetRecipes.none
+        using_fp8_recipe = recipe1() != GetRecipes.none()
         if using_fp8_recipe:
             with fp8_autocast(enabled=True, fp8_recipe=recipe1()):
                 y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_linear(x, w, bias, gradient)
@@ -393,7 +393,7 @@ class TestFP8RecipeLinearBase:
             y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_linear(x, w, bias, gradient)
 
         # recipe2
-        using_fp8_recipe = recipe2 != GetRecipes.none
+        using_fp8_recipe = recipe2() != GetRecipes.none()
         if using_fp8_recipe:
             with fp8_autocast(enabled=True, fp8_recipe=recipe2()):
                 y_q, dgrad, wgrad, bgrad = self.run_linear(x, w, bias, gradient)
@@ -608,7 +608,7 @@ class TestFP8RecipeLayerNormLinearBase(TestFP8RecipeLinearBase):
         )
 
         # recipe1
-        using_fp8_recipe = recipe1 != GetRecipes.none
+        using_fp8_recipe = recipe1() != GetRecipes.none()
         if using_fp8_recipe:
             with fp8_autocast(enabled=True, fp8_recipe=recipe1()):
                 y_q_ref, ln_out_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_layernorm_linear(
@@ -630,7 +630,7 @@ class TestFP8RecipeLayerNormLinearBase(TestFP8RecipeLinearBase):
         )
 
         # recipe2
-        using_fp8_recipe = recipe2 != GetRecipes.none
+        using_fp8_recipe = recipe2() != GetRecipes.none()
         if using_fp8_recipe:
             with fp8_autocast(enabled=True, fp8_recipe=recipe2()):
                 y_q, ln_out, dgrad, wgrad, bgrad = self.run_layernorm_linear(
......
@@ -176,7 +176,40 @@ class TestFloat8BlockwiseTensor:
     )
     @pytest.mark.parametrize("block_scaling_dim", [1, 2])
     @pytest.mark.parametrize("dq_columnwise", [True, False])
+    @pytest.mark.parametrize("all_gather_usage", [True, False])
     def test_quantize_dequantize_dims(
+        self,
+        dims: DimsType,
+        block_scaling_dim: int,
+        dq_columnwise: bool,
+        all_gather_usage: bool,
+    ) -> None:
+        if all_gather_usage and block_scaling_dim != 1:
+            pytest.skip("all_gather_usage only implemented for 1D block quantization.")
+        atol = _tols[tex.DType.kFloat8E4M3]["atol"]
+        rtol = _tols[tex.DType.kFloat8E4M3]["rtol"]
+        quantizer = Float8BlockQuantizer(
+            fp8_dtype=tex.DType.kFloat8E4M3,
+            rowwise=True,
+            columnwise=dq_columnwise,
+            block_scaling_dim=block_scaling_dim,
+            all_gather_usage=all_gather_usage,
+        )
+        self._test_quantize_dequantize(
+            quantizer=quantizer,
+            dims=dims,
+            atol=atol,
+            rtol=rtol,
+            dequant_columnwise=dq_columnwise,
+        )
+
+    @pytest.mark.parametrize(
+        "dims", [[], 256, 311, [264], [256, 512], [250, 500], [7, 5, 3], [2, 3, 5, 3]]
+    )
+    @pytest.mark.parametrize("block_scaling_dim", [1, 2])
+    @pytest.mark.parametrize("dq_columnwise", [True, False])
+    @pytest.mark.xfail(raises=NotImplementedError)
+    def test_quantize_dequantize_compact_format(
         self, dims: DimsType, block_scaling_dim: int, dq_columnwise: bool
     ) -> None:
         atol = _tols[tex.DType.kFloat8E4M3]["atol"]
@@ -186,6 +219,7 @@ class TestFloat8BlockwiseTensor:
             rowwise=True,
             columnwise=dq_columnwise,
             block_scaling_dim=block_scaling_dim,
+            all_gather_usage=True,
         )
         self._test_quantize_dequantize(
             quantizer=quantizer,
@@ -250,8 +284,13 @@ class TestFloat8BlockwiseTensor:
     @pytest.mark.parametrize("dims", [[256, 512], [250, 500]])
     @pytest.mark.parametrize("block_scaling_dim", [1, 2])
-    def test_serialization(self, dims: DimsType, block_scaling_dim: int) -> None:
+    @pytest.mark.parametrize("all_gather_usage", [True, False])
+    def test_serialization(
+        self, dims: DimsType, block_scaling_dim: int, all_gather_usage: bool
+    ) -> None:
         """Test serialization of Float8BlockwiseQTensor"""
+        if all_gather_usage and block_scaling_dim != 1:
+            pytest.skip("all_gather_usage only implemented for 1D block quantization.")
         device = "cuda"
         dtype = torch.bfloat16
         x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device)
@@ -260,6 +299,7 @@ class TestFloat8BlockwiseTensor:
             rowwise=True,
             columnwise=True,
             block_scaling_dim=block_scaling_dim,
+            all_gather_usage=all_gather_usage,
         )
 
         # Create FP8 tensor
@@ -283,6 +323,7 @@ class TestFloat8BlockwiseTensor:
         assert x_fp8_loaded._is_2D_scaled == x_fp8._is_2D_scaled
         assert x_fp8_loaded.dtype == x_fp8.dtype
         assert x_fp8_loaded._fp8_dtype == x_fp8._fp8_dtype
+        assert x_fp8_loaded._data_format == x_fp8._data_format
 
         # Test that dequantized values match
         x_fp8_dequant = x_fp8.dequantize()
......
@@ -252,11 +252,14 @@ struct QuantizationConfig {
   bool force_pow_2_scales = false;
   float amax_epsilon = 0.0f;
   NVTETensor noop_tensor = nullptr;
+  Float8BlockScaleTensorFormat float8_block_scale_tensor_format =
+      Float8BlockScaleTensorFormat::GEMM_READY;
 
   static constexpr size_t attr_sizes[] = {
-      sizeof(bool),       // force_pow_2_scales
-      sizeof(float),      // amax_epsilon
-      sizeof(NVTETensor)  // noop_tensor
+      sizeof(bool),                         // force_pow_2_scales
+      sizeof(float),                        // amax_epsilon
+      sizeof(NVTETensor),                   // noop_tensor
+      sizeof(Float8BlockScaleTensorFormat)  // float8_block_scale_tensor_format
   };
 };
......
@@ -302,6 +302,13 @@ enum NVTEQuantizationConfigAttribute {
      conditional early even when captured in a static CUDA graph.
   */
   kNVTEQuantizationConfigNoopTensor = 2,
+
+  /*! Data format for an FP8 block-scaled tensor
+   *
+   *  This is not the right design since the tensor format is a
+   *  property of the tensor, not the quantization. This enum will
+   *  likely be refactored away in the future.
+   */
+  kNVTEQuantizationConfigFloat8BlockScaleTensorFormat = 3,
+
   kNVTEQuantizationConfigNumAttributes
 };
@@ -721,6 +728,16 @@ class TensorWrapper {
   NVTETensor tensor_ = nullptr;
 };
 
+/*! \enum Float8BlockScaleTensorFormat
+ *  \brief Data format for an FP8 block-scaled tensor
+ */
+enum class Float8BlockScaleTensorFormat {
+  /*! FP8 data is transposed if needed and scales are swizzled */
+  GEMM_READY = 0,
+  /*! FP8 data is untransposed and scales are not swizzled or padded */
+  COMPACT = 1
+};
+
 /*! \struct QuantizationConfigWrapper
  *  \brief C++ wrapper for NVTEQuantizationConfigWrapper.
  */
@@ -774,6 +791,13 @@ class QuantizationConfigWrapper {
                                            sizeof(NVTETensor));
   }
 
+  /*! \brief Set FP8 block-scaled tensor format */
+  void set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat format) {
+    nvte_set_quantization_config_attribute(config_,
+                                           kNVTEQuantizationConfigFloat8BlockScaleTensorFormat,
+                                           &format, sizeof(Float8BlockScaleTensorFormat));
+  }
+
  private:
   /*! \brief Wrapped NVTEQuantizationConfig. */
   NVTEQuantizationConfig config_ = nullptr;
......
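
To make the two formats concrete: for 1D (1x128) block scaling they differ in whether the scale-inverse tensors are transposed and padded for cuBLAS, and whether the columnwise data copy is itself transposed. The helper below just computes the expected scale-inv shapes following the create_tensor logic later in this commit; it is illustrative, not a library API:

    import math

    def scale_inv_shapes(M, N, tile=128, fmt="GEMM_READY"):
        if fmt == "GEMM_READY":
            rowwise = (math.ceil(N / tile), math.ceil(M / 4) * 4)     # transposed + padded
            columnwise = (math.ceil(M / tile), math.ceil(N / 4) * 4)  # data transposed too
        else:  # COMPACT
            rowwise = (M, math.ceil(N / tile))     # natural layout, no padding
            columnwise = (math.ceil(M / tile), N)  # 128x1 tiles, data untransposed
        return rowwise, columnwise

    print(scale_inv_shapes(256, 1024))             # ((8, 256), (2, 1024))
    print(scale_inv_shapes(256, 1024, "COMPACT"))  # ((256, 8), (2, 1024))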
@@ -562,6 +562,9 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config,
     case kNVTEQuantizationConfigNoopTensor:
       std::memcpy(buf, &config_.noop_tensor, attr_size);
       break;
+    case kNVTEQuantizationConfigFloat8BlockScaleTensorFormat:
+      std::memcpy(buf, &config_.float8_block_scale_tensor_format, attr_size);
+      break;
     default:
       NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
   }
@@ -594,6 +597,9 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
     case kNVTEQuantizationConfigNoopTensor:
       std::memcpy(&config_.noop_tensor, buf, attr_size);
       break;
+    case kNVTEQuantizationConfigFloat8BlockScaleTensorFormat:
+      std::memcpy(&config_.float8_block_scale_tensor_format, buf, attr_size);
+      break;
     default:
       NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
   }
......
@@ -31,25 +31,27 @@ void quantize_transpose_square_blockwise(const SimpleTensor &input, SimpleTensor
 
 // enum class for rowwise usage
 enum class FP8BlockwiseRowwiseOption {
-  // No rowwise data
+  // No rowwise data, skip rowwise quantization
   NONE,
   // Rowwise data, scales in GEMM format
-  ROWWISE
-  // TODO: FP8 all gather requires some changes.
-  // 1. Compact scales are better for gathering than the GEMM format.
+  ROWWISE_GEMM_READY,
+  // Rowwise data, scales in compact format, needs extra processing (padding, transposing) before GEMM
+  ROWWISE_COMPACT
 };
 
 // enum class for columnwise usage
 // For Hopper sm90 with only TN fp8 gemm, there is need to do columnwise transpose when doing 1D block scaling
 enum class FP8BlockwiseColumnwiseOption {
-  // No columnwise data
+  // No columnwise data, skip columnwise quantization
   NONE,
   // Columnwise data transposed from original shape.
   // Scales in GEMM format corresponding to GEMM ingesting transposed column data.
-  COLUMNWISE_TRANSPOSE
-  // TODO: FP8 all gather requires some changes.
-  // 1. The transpose gets in the way of the all gather.
-  // 2. Compact scales are better for gathering than the GEMM format.
+  // On Hopper sm90, GEMM_READY means that columnwise quantization also fuses transpose op
+  // On higher sm versions with TN,NT,NN fp8 gemm, GEMM_READY doesn't fuse transpose
+  COLUMNWISE_GEMM_READY,
+  // Columnwise data in original shape
+  // Scales in compact format, needs extra processing (padding, transposing) before GEMM
+  COLUMNWISE_COMPACT
 };
 
 void quantize_transpose_vector_blockwise(const SimpleTensor &input, SimpleTensor &scale_inv,
......
@@ -99,14 +99,14 @@ Step 2: Cast and store to output_c
 | ... |
 +-------------------------------+-------------------------------+-------------------------------+-------------------------------+
 
-Step 3: Transpose, cast and store to output_t
+Step 3 (if columnwise transpose is True, GEMM_READY): Transpose, cast and store to output_t
   * shard memory: 128x128 elements with type=InputType (below graph doesn't consider padding)
   * 8 warps
   * Loop 2 times
   * What each thread does in each loop:
     * 2 elements (in a row) are read from the shared memory at a time, for a total of 16 times
     * Every 8 consecutive threads do reduction and calculate the amax of each column
-    * 16 elements are quantized and write to output_c at a time, for a total of 2 times
+    * 16 elements are quantized and write to output_t at a time, for a total of 2 times
 
 +------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+
 | T0 | T8 | T16 | T24 | | | | T0 | T8 | T16 | T24 | | | |
 | T1 | T9 | T17 | T25 | | | | T1 | T9 | T17 | T25 | | | |
@@ -118,6 +118,29 @@ Step 3: Transpose, cast and store to output_t
 | T7 | T15 | T23 | T31 | | | | T7 | T15 | T23 | T31 | | | |
 +-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+
 
+Step 3 (if columnwise transpose is False, COMPACT format): Skip Transpose, cast and store to output_t
+  * shard memory: 128x128 elements with type=InputType (below graph doesn't consider padding)
+  * 8 warps
+  * Loop 1 times
+  * What each thread does in each loop:
+    * 16 elements (in a row) are read from the shared memory, for a total of 4 rows,
+      it needs 8 reads in smem to get 16 elements in a row, thread tile shape is 16x4
+    * Every 32 consecutive threads in a warp do reduction and calculate the amax of each column,
+      so each thread will do warp shuffle 16 times to get the amax of each column
+    * 16 elements are quantized and write to output_t at a time, for a total of 4 times
+
++------16 elements-------+------16 elements-------+-----80 elements-----+------16 elements------+
+| T0 | | | |
+| T1 | | | |
+| T2 | | | |
+| T3 | | | |
+| T4 | | | |
+| T5 | | | |
+| T6 | | | |
+| T7 | | | |
+| ... | | | |
+| T31 | | | |
++-----------------------+-----------------------+-----------------------+-----------------------+
+
 */
 // clang-format on
@@ -140,6 +163,7 @@ constexpr int kNumThreadsLoad = kTileDim / kNVecIn;
 constexpr int kNumThreadsStore = kTileDim / kNVecOut;
 static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kThreadsPerWarp");
 static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp");
+constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp;
 
 template <bool kAligned, typename CType, typename IType, typename OType>
 __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel(
@@ -149,9 +173,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
     const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon,
     FP8BlockwiseRowwiseOption rowwise_option, FP8BlockwiseColumnwiseOption columnwise_option,
     const bool pow_2_scaling) {
-  bool return_rowwise = rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE;
-  bool return_columnwise_transpose =
-      columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE;
+  bool return_rowwise = rowwise_option != FP8BlockwiseRowwiseOption::NONE;
+  bool return_columnwise_gemm_ready =
+      columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
+  bool return_columnwise_compact =
+      columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT;
 
   using SMemVec = Vec<IType, kNVecSMem>;
   using OVec = Vec<OType, kNVecOut>;
@@ -299,8 +325,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
     }
   }
 
-  // Step 3: Transpose, cast and store to output_t
-  if (return_columnwise_transpose) {
+  // Step 3 (return_columnwise_gemm_ready): Transpose, cast and store to output_t
+  if (return_columnwise_gemm_ready) {
     constexpr int c_stride =
         kThreadsPerBlock / kNumThreadsStore;  // Stride in columns of shared memory
     constexpr int num_iterations = kTileDim / (c_stride * kNVecSMem);
@@ -385,6 +411,103 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
       }
     }
   }
+
+  // Step 4 (return_columnwise_compact): cast in 128x1 style and store to output, skip transpose
+  if (return_columnwise_compact) {
+    // thread tile should be 4x16, 16 means 8 smem reads
+    constexpr int kThreadTileRow = kTileDim / kThreadsPerWarp;
+    constexpr int kThreadTileCol = kNVecOut;
+    using RegVec = Vec<IType, kThreadTileCol>;
+    using RegScaleVec = Vec<CType, kThreadTileCol>;
+    constexpr int num_smem_reads = kNVecOut / kNVecSMem;
+    // c_stride will not be used here because we only have one iteration
+    // constexpr int c_stride = kThreadTileCol * kNumWarps / kNVecSMem;
+    constexpr int num_iterations =
+        kTileDim / (kNumWarps * kThreadTileCol);  // should be only one iteration
+    static_assert(num_iterations == 1,
+                  "num_iterations should be 1 for columnwise non-transpose case");
+
+    const int thr_idx_in_warp = threadIdx.x % kThreadsPerWarp;
+    const int warp_idx = threadIdx.x / kThreadsPerWarp;
+    const int r_s = thr_idx_in_warp * kThreadTileRow;  // Row in shared memory
+    int c_s = warp_idx * num_smem_reads;               // Column in shared memory
+    size_t r_g = static_cast<size_t>(blockIdx.y) * kTileDim + r_s;  // Row in global memory
+    const size_t c_g =
+        static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem;  // Column in global memory
+    const size_t num_ele = c_g < row_length
+                               ? min(static_cast<size_t>(kThreadTileCol), row_length - c_g)
+                               : 0;  // For not aligned case
+#pragma unroll
+    for (int iter = 0; iter < num_iterations; ++iter) {
+      RegVec reg_vec[kThreadTileRow];
+      RegScaleVec thr_scale;
+      // Step 3.1: Load from shared memory to registers
+#pragma unroll
+      for (int i = 0; i < kThreadTileRow; ++i) {
+        int r = r_s + i;
+#pragma unroll
+        for (int j = 0; j < num_smem_reads; ++j) {
+          int c = c_s + j;
+          SMemVec smem_vec = smem[r * kSMemCol + c];
+          // copy smem_vec to reg vec with its elements
+#pragma unroll
+          for (int k = 0; k < kNVecSMem; ++k) {
+            reg_vec[i].data.elt[j * kNVecSMem + k] = smem_vec.data.elt[k];
+          }
+        }
+      }
+#pragma unroll
+      for (int reg_idx = 0; reg_idx < kThreadTileCol; ++reg_idx) {
+        // Step 3.2: Compute local amax
+        CType amax = 0;
+#pragma unroll
+        for (int i = 0; i < kThreadTileRow; ++i) {
+          amax = fmaxf(amax, fabsf(reg_vec[i].data.elt[reg_idx]));
+        }
+        // Step 3.3: Reduce amax
+        const bool is_src_lane = thr_idx_in_warp == 0;
+        amax = warp_reduce_max<kThreadsPerWarp>(amax);
+        constexpr int lane_zero = 0;
+        amax = __shfl_sync(0xFFFFFFFF, amax, lane_zero);
+        // Step 3.4: Compute scale
+        CType scale;
+        scale = compute_scale_from_types<IType, OType>(amax, epsilon, pow_2_scaling);
+        thr_scale.data.elt[reg_idx] = scale;
+        // Step 3.5: Write scale_inv_t
+        bool write_scale_inv = is_src_lane;
+        if constexpr (!kAligned) {
+          write_scale_inv &= (c_g + reg_idx < row_length);
+        }
+        if (write_scale_inv) {
+          CType scale_inv = 1.0 / scale;
+          size_t row_idx = static_cast<size_t>(blockIdx.y);
+          size_t col_idx = static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem + reg_idx;
+          tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv;
+        }
+      }
+      // Step 3.6: Quantize
+      for (int row_idx = 0; row_idx < kThreadTileRow; ++row_idx) {
+        OType* output_g =
+            &output_t[(r_g + row_idx) * row_length + c_g];  // Output address in global memory
+        OVec output_vec;
+#pragma unroll
+        for (int i = 0; i < kThreadTileCol; ++i) {
+          output_vec.data.elt[i] = static_cast<OType>(
+              static_cast<CType>(reg_vec[row_idx].data.elt[i]) * thr_scale.data.elt[i]);
+        }
+        // Step 3.7: Store output_t
+        if constexpr (kAligned) {
+          output_vec.store_to(output_g);
+        } else {
+          if (r_g + row_idx < num_rows) {
+            output_vec.store_to_elts(output_g, 0, num_ele);
+          }
+        }
+      }
+      // Step 3.8: Update output address, column index of shared memory
+      // this section shouldn't matter since we only have one iteration
+    }
+  }
 }
 
 }  // namespace
@@ -400,11 +523,6 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
     const bool pow2_scale, cudaStream_t stream) {
   NVTE_API_CALL(quantize_transpose_vector_blockwise);
 
-  // assert that rowwise_option and columnwise_option are not both NONE
-  NVTE_CHECK(rowwise_option != FP8BlockwiseRowwiseOption::NONE ||
-                 columnwise_option != FP8BlockwiseColumnwiseOption::NONE,
-             "rowwise_option and columnwise_option cannot both be NONE");
-
   const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u;
   size_t num_elements = row_length;
   size_t num_rows = 1;
@@ -425,32 +543,43 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
   size_t scale_t_stride_y = 0;
 
   if (rowwise_option != FP8BlockwiseRowwiseOption::NONE) {
-    NVTE_CHECK(rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE,
+    NVTE_CHECK(rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY ||
+                   rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE_COMPACT,
                "Unexpected rowwise enum value");
     NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape.");
     NVTE_CHECK(scale_inv.shape.size() == 2, "Scale dimension must be 2.");
     size_t scale_k = scale_inv.shape[1];
-    scale_stride_x = scale_k;
-    scale_stride_y = 1;
+    bool rowwise_compact = rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE_COMPACT;
+    scale_stride_x = rowwise_compact ? 1 : scale_k;
+    scale_stride_y = rowwise_compact ? scale_k : 1;
   }
 
   if (columnwise_option != FP8BlockwiseColumnwiseOption::NONE) {
-    NVTE_CHECK(columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE,
-               "Unexpected columnwise enum value");
     NVTE_CHECK(output_t.shape.size() == input.shape.size(),
                "output_t must have same number of dimensions as input.");
     if (output_t.shape.size() > 0) {
-      NVTE_CHECK(output_t.shape[0] == row_length, "Wrong dimension 0 of output_t.");
-      for (size_t i = 1; i < output_t.shape.size(); ++i) {
-        NVTE_CHECK(output_t.shape.at(i) == input.shape.at(i - 1), "Wrong dimension in output_t");
+      if (columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY) {
+        NVTE_CHECK(output_t.shape[0] == row_length, "Wrong dimension 0 of output_t.");
+        for (size_t i = 1; i < output_t.shape.size(); ++i) {
+          NVTE_CHECK(output_t.shape.at(i) == input.shape.at(i - 1), "Wrong dimension in output_t");
+        }
+      } else {
+        NVTE_CHECK(columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT,
                   "Unexpected columnwise option enum value");
+        NVTE_CHECK(output_t.shape[0] == input.shape[0], "Wrong dimension 0 of output_t.");
+        NVTE_CHECK(
+            input.shape == output_t.shape,
+            "Input and output_t must have the same shape for columnwise non-transpose case.");
       }
     }
 
     NVTE_CHECK(output.dtype == output_t.dtype, "output and output_t need to have the same dtype.");
     NVTE_CHECK(scale_inv_t.shape.size() == 2, "Scale_t dimension must be 2.");
-    scale_t_stride_x = scale_inv_t.shape[1];
-    scale_t_stride_y = 1;
+    bool columnwise_compact = columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT;
+    size_t scale_t_k = scale_inv_t.shape[1];
+    scale_t_stride_x = columnwise_compact ? 1 : scale_t_k;
+    scale_t_stride_y = columnwise_compact ? scale_t_k : 1;
   }
 
   const size_t num_blocks_x = DIVUP(row_length, (size_t)kTileDim);
......
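
The new COLUMNWISE_COMPACT kernel path above amounts to 128x1 quantization of each column tile with the data kept in its original layout. A non-kernel reference of the same math, as a sketch (divisible shapes and e4m3 output assumed; the kernel additionally handles padding, epsilon, and pow-2 scales):

    import torch

    def quantize_columnwise_compact(x: torch.Tensor, tile: int = 128):
        M, N = x.shape
        assert M % tile == 0, "sketch assumes M divisible by the tile length"
        fp8_max = torch.finfo(torch.float8_e4m3fn).max
        tiles = x.view(M // tile, tile, N).float()
        amax = tiles.abs().amax(dim=1)           # (M // tile, N) column-tile amaxes
        scale = fp8_max / amax.clamp(min=1e-12)
        q = (tiles * scale.unsqueeze(1)).clamp(-fp8_max, fp8_max)
        return q.view(M, N).to(torch.float8_e4m3fn), 1.0 / scale

    q, s_inv = quantize_columnwise_compact(torch.randn(256, 1024))
    print(q.shape, s_inv.shape)  # torch.Size([256, 1024]) torch.Size([2, 1024])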
@@ -1283,12 +1283,25 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
                "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D");
     bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : false;
     float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f;
-    FP8BlockwiseRowwiseOption rowwise_option = output_tensor->has_data()
-                                                   ? FP8BlockwiseRowwiseOption::ROWWISE
-                                                   : FP8BlockwiseRowwiseOption::NONE;
-    FP8BlockwiseColumnwiseOption columnwise_option =
-        output_tensor->has_columnwise_data() ? FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE
-                                             : FP8BlockwiseColumnwiseOption::NONE;
+    FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE;
+    FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE;
+    if (output_tensor->has_data()) {
+      bool rowwise_compact = quant_config_cpp
+                                 ? quant_config_cpp->float8_block_scale_tensor_format ==
+                                       Float8BlockScaleTensorFormat::COMPACT
+                                 : false;
+      rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT
+                                       : FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
+    }
+    if (output_tensor->has_columnwise_data()) {
+      bool columnwise_compact = quant_config_cpp
+                                    ? quant_config_cpp->float8_block_scale_tensor_format ==
+                                          Float8BlockScaleTensorFormat::COMPACT
+                                    : false;
+      columnwise_option = columnwise_compact
+                              ? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT
+                              : FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
+    }
     quantize_transpose_vector_blockwise(input_tensor->data, output_tensor->scale_inv,
                                         output_tensor->columnwise_scale_inv, output_tensor->data,
                                         output_tensor->columnwise_data, epsilon, rowwise_option,
......
@@ -75,6 +75,10 @@
       .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen)  \
       .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8)                                    \
       .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend);                     \
+  pybind11::enum_<transformer_engine::Float8BlockScaleTensorFormat>(                           \
+      m, "Float8BlockScaleTensorFormat", pybind11::module_local())                             \
+      .value("GEMM_READY", transformer_engine::Float8BlockScaleTensorFormat::GEMM_READY)       \
+      .value("COMPACT", transformer_engine::Float8BlockScaleTensorFormat::COMPACT);            \
   pybind11::enum_<transformer_engine::CommOverlapType>(m, "CommOverlapType",                   \
                                                        pybind11::module_local())              \
       .value("RS", transformer_engine::CommOverlapType::RS)                                    \
......
@@ -77,6 +77,14 @@ def general_gemm(
         # There is not use_split_accumulator == False
         # implementation for Float8BlockwiseQTensorBase GEMM
         use_split_accumulator = True
+
+        # Check that data format is supported
+        if (
+            A._data_format != tex.Float8BlockScaleTensorFormat.GEMM_READY
+            or B._data_format != tex.Float8BlockScaleTensorFormat.GEMM_READY
+        ):
+            raise RuntimeError("GEMM with Float8BlockwiseQTensor requires GEMM_READY format")
+
     args = (
         A,
         transa,  # transa
......
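
This check is why the gather path converts formats afterward: `_post_process_fp8_blockwise_gather` (added later in this commit) transposes the columnwise data and the rowwise scale-inv of a gathered COMPACT tensor and flips its `_data_format` to GEMM_READY before any GEMM sees it. The rowwise half of that conversion reduces to a transpose, sketched here for illustration only:

    import torch

    def rowwise_scales_compact_to_gemm_ready(scale_inv: torch.Tensor) -> torch.Tensor:
        # (M, ceil(N/128)) compact layout -> (ceil(N/128), M), as cuBLAS expects
        return scale_inv.transpose(-2, -1).contiguous()

    s = torch.rand(512, 8)  # compact scales of a gathered 512x1024 tensor
    print(rowwise_scales_compact_to_gemm_ready(s).shape)  # torch.Size([8, 512])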
@@ -173,6 +173,8 @@ class Float8BlockQuantizer : public Quantizer {
   bool force_pow_2_scales = false;
   // Amax within quantization tile has a floor of epsilon.
   float amax_epsilon = 0.0;
+  // Whether quantized tensor will be used in an all-gather
+  bool all_gather_usage = false;
 
  private:
   int block_scaling_dim = 2;
......
@@ -81,6 +81,9 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob
     auto my_quantizer_bw = static_cast<Float8BlockQuantizer*>(my_quantizer.get());
     quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
     quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
+    if (my_quantizer_bw->all_gather_usage) {
+      quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT);
+    }
   }
   NVTE_SCOPED_GIL_RELEASE({
     nvte_quantize_v2(te_input.data(), te_output.data(), quant_config,
......
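
Putting the plumbing together, end-to-end usage looks roughly like the sketch below. The constructor arguments mirror the unit tests earlier in this commit; the import path is assumed from Transformer Engine's module layout and may differ:

    import torch
    import transformer_engine_torch as tex
    from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer

    quantizer = Float8BlockQuantizer(
        fp8_dtype=tex.DType.kFloat8E4M3,
        rowwise=True,
        columnwise=True,
        block_scaling_dim=1,    # the compact format requires 1x128 (1D) scaling
        all_gather_usage=True,  # quantize into COMPACT format for the all-gather
    )
    x = torch.randn(256, 1024, device="cuda")
    x_fp8 = quantizer(x)  # routed through nvte_quantize_v2 with the COMPACT attribute set
    assert x_fp8._data_format == tex.Float8BlockScaleTensorFormat.COMPACT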
@@ -170,6 +170,9 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
     auto my_quantizer_bw = static_cast<Float8BlockQuantizer *>(my_quantizer.get());
     quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
     quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
+    if (my_quantizer_bw->all_gather_usage) {
+      quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT);
+    }
   }
   NVTE_SCOPED_GIL_RELEASE({
     nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config,
@@ -328,6 +331,9 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
     auto my_quantizer_bw = static_cast<Float8BlockQuantizer *>(my_quantizer.get());
     quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
     quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
+    if (my_quantizer_bw->all_gather_usage) {
+      quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT);
+    }
   }
   NVTE_SCOPED_GIL_RELEASE({
     nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config,
......
@@ -261,6 +261,7 @@ Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quanti
   this->amax_epsilon = quantizer.attr("amax_epsilon").cast<float>();
   NVTE_CHECK(this->block_scaling_dim == 1 || this->block_scaling_dim == 2,
              "Unsupported block scaling dim.");
+  this->all_gather_usage = quantizer.attr("all_gather_usage").cast<bool>();
 }
 
 void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const {
@@ -299,6 +300,10 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
   size_t m_dim = numel / k_dim;
   constexpr size_t kBlockLen = 128;
 
+  Float8BlockScaleTensorFormat data_format =
+      (all_gather_usage ? Float8BlockScaleTensorFormat::COMPACT
+                        : Float8BlockScaleTensorFormat::GEMM_READY);
+
   if (rowwise_usage) {
     if (rowwise_data.has_value()) {
       data_rowwise = std::move(*rowwise_data);
@@ -308,16 +313,26 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
     size_t sinv0 = 0;
     size_t sinv1 = 0;
     if (block_scaling_dim == 2) {
+      // 2D scaling is always GEMM_READY for now
+      NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY,
+                 "2D scaling is always GEMM_READY for now.");
       sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
       sinv1 = roundup((k_dim + kBlockLen - 1) / kBlockLen, 4);
     } else if (block_scaling_dim == 1) {
+      // 1D scaling can be GEMM_READY or COMPACT
+      bool rowwise_compact = data_format == Float8BlockScaleTensorFormat::COMPACT;
+      // the default rowwise scaling factor shape already transposes the scaling factor, so it's GEMM_READY
      sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
-      sinv1 = roundup(m_dim, 4);
+      sinv1 = rowwise_compact ? m_dim : roundup(m_dim, 4);
+      // if the rowwise format is compact, the scaling factor is not transposed
+      if (rowwise_compact) {
+        std::swap(sinv0, sinv1);
+      }
     } else {
-      NVTE_CHECK(false,
-                 "Unsupported block_scaling_dim in create_tensor rowwise."
-                 "Expected 1 or 2. Got ",
-                 block_scaling_dim);
+      NVTE_ERROR(
+          "Unsupported block_scaling_dim in create_tensor rowwise. "
+          "Expected 1 or 2. Got ",
+          block_scaling_dim);
     }
     scale_inv_rowwise =
         at::empty({static_cast<int64_t>(sinv0), static_cast<int64_t>(sinv1)}, scale_opts);
@@ -332,28 +347,43 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
     NVTE_CHECK(torch_shape.size() == shape.size(), "Shape expected to match torch shape. Shape ",
                columnwise_shape, " torch shape: ", torch_columnwise_shape);
     if (torch_shape.size() > 0) {
-      torch_columnwise_shape.reserve(torch_shape.size());
-      columnwise_shape.reserve(shape.size());
-      torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]);
-      columnwise_shape.push_back(shape[shape.size() - 1]);
-      for (size_t i = 0; i < torch_shape.size() - 1; ++i) {
-        torch_columnwise_shape.push_back(torch_shape[i]);
-        columnwise_shape.push_back(shape[i]);
+      if (!all_gather_usage) {
+        torch_columnwise_shape.reserve(torch_shape.size());
+        columnwise_shape.reserve(shape.size());
+        torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]);
+        columnwise_shape.push_back(shape[shape.size() - 1]);
+        for (size_t i = 0; i < torch_shape.size() - 1; ++i) {
+          torch_columnwise_shape.push_back(torch_shape[i]);
+          columnwise_shape.push_back(shape[i]);
+        }
+      } else {
+        // assert we are doing 1D scaling
+        NVTE_CHECK(block_scaling_dim == 1,
+                   "Compact columnwise format is not supported for 128x128 2D block scaling.");
+        torch_columnwise_shape = torch_shape;
+        columnwise_shape = shape;
       }
     }
     size_t sinv0 = 0;
     size_t sinv1 = 0;
     if (block_scaling_dim == 2) {
+      // 2D scaling is always GEMM_READY for now
+      NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY,
+                 "2D scaling is always GEMM_READY for now.");
       sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
       sinv1 = roundup((m_dim + kBlockLen - 1) / kBlockLen, 4);
     } else if (block_scaling_dim == 1) {
+      bool columnwise_compact = data_format == Float8BlockScaleTensorFormat::COMPACT;
       sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
-      sinv1 = roundup(k_dim, 4);
+      sinv1 = columnwise_compact ? k_dim : roundup(k_dim, 4);
+      // GEMM_READY case: the scaling factor is [sinv0, sinv1], already transposed here for cuBLAS.
+      // For the COMPACT case, since we apply 128x1 scaling here without transposing the columnwise
+      // data, the scaling factor is also [sinv0, sinv1], so no need to swap sinv0 and sinv1 here.
     } else {
-      NVTE_CHECK(false,
-                 "Unsupported block_scaling_dim in create_tensor columnwise."
-                 "Expected 1 or 2. Got ",
-                 block_scaling_dim);
+      NVTE_ERROR(
+          "Unsupported block_scaling_dim in create_tensor columnwise. "
+          "Expected 1 or 2. Got ",
+          block_scaling_dim);
     }
     data_colwise = at::empty(torch_columnwise_shape, opts);
     scale_inv_colwise =
@@ -373,7 +403,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
         "rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise,
         "rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise,
         "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer,
-        "is_2D_scaled"_a = (block_scaling_dim == 2));
+        "is_2D_scaled"_a = (block_scaling_dim == 2), "data_format"_a = data_format);
   } else {
     py::handle Float8BlockwiseQTensorClass(
         reinterpret_cast<PyObject*>(Float8BlockwiseQTensorPythonClass));
@@ -381,7 +411,8 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
         "shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), "rowwise_data"_a = data_rowwise,
         "columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = scale_inv_rowwise,
         "columnwise_scale_inv"_a = scale_inv_colwise, "fp8_dtype"_a = this->dtype,
-        "quantizer"_a = this->quantizer, "is_2D_scaled"_a = (block_scaling_dim == 2));
+        "quantizer"_a = this->quantizer, "is_2D_scaled"_a = (block_scaling_dim == 2),
+        "data_format"_a = data_format);
   }
 
   return {std::move(tensor), std::move(ret)};
......
...@@ -8,6 +8,7 @@ from __future__ import annotations ...@@ -8,6 +8,7 @@ from __future__ import annotations
from collections.abc import Iterable from collections.abc import Iterable
from contextlib import contextmanager, AbstractContextManager, ContextDecorator from contextlib import contextmanager, AbstractContextManager, ContextDecorator
from functools import lru_cache from functools import lru_cache
from dataclasses import dataclass
import math
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import warnings
@@ -19,6 +20,15 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._common_utils import _get_module_fsdp_state
from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_modules
try:
import torch.distributed._symmetric_memory as symm_mem
HAS_TORCH_SYMMETRIC = True
except ImportError:
HAS_TORCH_SYMMETRIC = False
import transformer_engine_torch as tex
from . import torch_version
from .utils import (
is_non_tn_fp8_gemm_supported,
@@ -34,14 +44,8 @@ from .tensor.quantized_tensor import QuantizedTensor, Quantizer
from .tensor._internal.float8_tensor_base import Float8TensorBase
from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from .tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor, DebugQuantizer
__all__ = ["checkpoint", "CudaRNGStatesTracker"]
@@ -977,6 +981,67 @@ def _all_gather_fp8(
return out, handle
def _set_quantizer_format(quantizer: Quantizer, compact: bool = False) -> None:
"""Make quantizer compact"""
_quantizer = quantizer
if isinstance(quantizer, DebugQuantizer):
_quantizer = quantizer.parent_quantizer
if isinstance(_quantizer, Float8BlockQuantizer):
_quantizer.all_gather_usage = compact
def _post_process_fp8_blockwise_gather(
out: Float8BlockwiseQTensorBase,
quantizer: Float8BlockQuantizer,
handle: Optional[torch.distributed.Work] = None,
) -> Float8BlockwiseQTensorBase:
"""Post-process FP8 blockwise gather."""
if handle is not None:
handle.wait()
handle = None
if out._is_gemm_ready_format():
return out
needs_columnwise_data_transpose = (
quantizer is not None and quantizer.columnwise_usage and not is_non_tn_fp8_gemm_supported()
)
needs_rowwise_scale_transpose = (
quantizer is not None and quantizer.rowwise_usage and not is_non_tn_fp8_gemm_supported()
)
# cuBLAS requires a transpose of the scale-inv tensor. Suppose the original input is 256x1024.
# Columnwise COMPACT format means doing 128x1 quantization of it,
# so the quantized tensor is 256x1024 and the scale-inv is 2x1024.
# If we were in GEMM_READY format, it would be equivalent to 1x128 quantization
# of the transposed 1024x256 tensor, whose scale-inv is 1024x2; cuBLAS requires 2x1024.
# Therefore it turns out we don't need to transpose the scale-inv, only the columnwise data.
if needs_columnwise_data_transpose:
out._transpose_columnwise_data()
if needs_rowwise_scale_transpose:
out._rowwise_scale_inv = out._rowwise_scale_inv.transpose(-2, -1).contiguous()
out._data_format = tex.Float8BlockScaleTensorFormat.GEMM_READY
return out
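A shape-only sketch of the comment's argument, using the same 256x1024 example (plain PyTorch, no TE types; the variable names are illustrative):

    import torch

    # COMPACT columnwise: 128x1 quantization of a 256x1024 input, no transpose.
    data = torch.zeros(256, 1024, dtype=torch.uint8)  # quantized bytes
    scale_inv = torch.zeros(2, 1024)                  # 256 / 128 = 2 block rows

    # GEMM_READY is equivalent to 1x128 quantization of the 1024x256 transpose,
    # whose natural scale-inv is 1024x2; cuBLAS wants its transpose, i.e. 2x1024.
    # The compact scale-inv already has that shape, so only the data moves:
    gemm_ready_data = data.t().contiguous()           # 1024x256
    assert scale_inv.shape == (2, 1024)               # unchanged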
@dataclass
class _FP8BlockwiseAllGatherAsyncHandle:
"""Handle for asynchronous FP8 blockwise all-gather."""
tensor: Float8BlockwiseQTensorBase
quantizer: Float8BlockQuantizer
async_handle: torch.distributed.Work
_synchronized: bool = False
def wait(self) -> None:
"""Wait for the async operation to complete and post-process the tensor."""
if self._synchronized:
return
self.async_handle.wait()
_post_process_fp8_blockwise_gather(self.tensor, self.quantizer)
self._synchronized = True
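The handle is designed so that wait() is idempotent and post-processing runs exactly once; a toy analog of that pattern (not the TE class itself):

    from dataclasses import dataclass

    @dataclass
    class _DemoHandle:
        post_processed: int = 0
        _synchronized: bool = False

        def wait(self) -> None:
            if self._synchronized:
                return  # repeated wait() is a no-op
            # stands in for async_handle.wait() + gather post-processing
            self.post_processed += 1
            self._synchronized = True

    h = _DemoHandle()
    h.wait(); h.wait()
    assert h.post_processed == 1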
def _all_gather_fp8_blockwise(
inp: torch.Tensor,
process_group: dist_group_type,
@@ -990,8 +1055,9 @@ def _all_gather_fp8_blockwise(
Returns: quantizer(gather(inp))
NOTE: The implementation only honors async_op=True for the FP8 gather case.
If the tensor shape is not divisible by 128, it falls back to a synchronous
gather and invokes the quantizer on the gathered result.
"""
# Input tensor attributes
@@ -1027,7 +1093,11 @@ def _all_gather_fp8_blockwise(
out_shape[0] *= world_size
# Doing BF16 gather for now as a baseline because it's simpler
if (
not isinstance(inp, Float8BlockwiseQTensorBase)
and quantizer is not None
and not quantizer.is_quantizable(inp)
):
out = torch.empty(
out_shape,
dtype=dtype,
@@ -1035,14 +1105,93 @@ def _all_gather_fp8_blockwise(
memory_format=torch.contiguous_format,
)
torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
orig_all_gather_usage = quantizer.all_gather_usage
quantizer.all_gather_usage = False
out = quantizer(out)
quantizer.all_gather_usage = orig_all_gather_usage
return out, None
# Implementation of fp8 gather needs to account for:
# * Getting columnwise data as a transpose of how it is stored for GEMMs.
# * Gathering non-GEMM-swizzled scales.

# Cast input tensor to Float8BlockwiseQTensor with required data
# Set to compact usage in case the quantizer is not correctly configured
orig_all_gather_usage = quantizer.all_gather_usage
quantizer.all_gather_usage = True
if not isinstance(inp, Float8BlockwiseQTensorBase):
inp = quantizer(inp)
elif (quantizer.rowwise_usage and inp._rowwise_data is None) or (
quantizer.columnwise_usage and inp._columnwise_data is None
):
warnings.warn(
"Input and quantizer do not have matching usages. "
"Dequantizing and requantizing to Float8BlockwiseQTensor."
)
inp = quantizer(inp.dequantize())
quantizer.all_gather_usage = orig_all_gather_usage
# Before starting network communication, make sure the data is in compact format
if inp._data_format != tex.Float8BlockScaleTensorFormat.COMPACT:
raise RuntimeError(
"All-gather with FP8 block-wise quantized tensor requires compact data format, "
f"but found data_format={inp._data_format}"
)
# Construct Float8BlockwiseQTensor output tensor
out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
# Coalesce NCCL collectives
with torch.distributed._coalescing_manager(
group=process_group,
device=device,
async_ops=async_op,
) as coalescing_manager:
# Gather Float8BlockwiseQTensor data for row-wise usage
if quantizer.rowwise_usage:
# Launch all-gathers
torch.distributed.all_gather_into_tensor(
out._rowwise_scale_inv,
inp._rowwise_scale_inv,
group=process_group,
)
torch.distributed.all_gather_into_tensor(
out._rowwise_data,
inp._rowwise_data,
group=process_group,
)
# Gather Float8BlockwiseQTensor data for column-wise usage
if quantizer.columnwise_usage:
# Launch all-gathers
torch.distributed.all_gather_into_tensor(
out._columnwise_scale_inv,
inp._columnwise_scale_inv,
group=process_group,
)
torch.distributed.all_gather_into_tensor(
out._columnwise_data,
inp._columnwise_data,
group=process_group,
)
handle = coalescing_manager if async_op else None
# Unlike MXFP8, this FP8 blockwise tensor primarily targets Hopper,
# which means we need to transpose the gathered columnwise data.
# An example use is the grad_output tensor, i.e. dY in linear backward:
# we gather two FP8 tensors (rowwise and columnwise) along dim 0
# and then transpose the columnwise data to match the rowwise data.
# Make sure the FP8 transpose is populated if needed.
if async_op:
handle = _FP8BlockwiseAllGatherAsyncHandle(out, quantizer, handle)
else:
# For a sync op, do the transpose here as a post-processing step
_post_process_fp8_blockwise_gather(out, quantizer, handle)
return out, handle
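Sketch of the intended call pattern, assuming an initialized process group, shapes divisible by 128 so the FP8 path is taken, and the async_op/quantizer keyword arguments shown in the body above (the wrapper below is hypothetical):

    def gather_ln_out_for_wgrad(ln_out, quantizer, tp_group):
        # Kick off the coalesced FP8 all-gather without blocking.
        out, handle = _all_gather_fp8_blockwise(
            ln_out, tp_group, async_op=True, quantizer=quantizer
        )
        # ... overlap independent work here while NCCL runs ...
        if handle is not None:
            handle.wait()  # finish comm, then post-process COMPACT -> GEMM_READY
        return out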
def _all_gather_mxfp8(
@@ -1267,6 +1416,9 @@ def gather_along_first_dim(
)
if isinstance(inp, QuantizedTensor):
inp = inp.dequantize()
# Falling back to high-precision all-gather for Float8BlockQuantizer
# means that it should directly output GEMM_READY format
_set_quantizer_format(quantizer, compact=False)
out = torch.empty(
out_shape,
dtype=inp.dtype,
...
@@ -63,10 +63,11 @@ from ..tensor.quantized_tensor import (
)
from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.utils import any_feature_enabled
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ..cpp_extensions import (
@@ -183,12 +184,6 @@ class _LayerNormLinear(torch.autograd.Function):
# All-gather is not supported with FP8 column-wise data
input_quantizer.set_usage(columnwise=False)
# Avoid quantized norm kernel if norm output will be returned
# or if a gather of ln_out must be in high precision.
with_quantized_norm = (
@@ -196,7 +191,6 @@ class _LayerNormLinear(torch.autograd.Function):
and not debug
and not return_layernorm_output
and not return_layernorm_output_gathered
)
# Apply normalization
@@ -233,15 +227,16 @@ class _LayerNormLinear(torch.autograd.Function):
ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
ln_out_return = ln_out_total
if fp8 or debug:
ln_out = input_quantizer(ln_out)
input_quantizer.set_usage(rowwise=True, columnwise=False)
if isinstance(input_quantizer, Float8BlockQuantizer):
input_quantizer.all_gather_usage = False
ln_out_total = input_quantizer(ln_out_total)
else:
quantizer = None
if fp8 or debug:
quantizer = input_quantizer
if not with_quantized_norm:
ln_out = quantizer(ln_out)
quantizer.set_usage(rowwise=True, columnwise=False)
if ub_overlap_ag_fprop: # Initialize Userbuffers all-gather
@@ -391,7 +386,6 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.ln_out_needs_gather = (
weight.requires_grad and parallel_mode == "column" and sequence_parallel
)
# Input with column-wise usage is needed for wgrad GEMM.
if backward_needs_input:
@@ -399,7 +393,10 @@ class _LayerNormLinear(torch.autograd.Function):
# For sequence parallel in vanilla FP8, rowwise data is used
# to gather the input. For MXFP8, only columnwise data
# can be all-gathered.
if (
isinstance(ln_out, (MXFP8TensorBase, Float8BlockwiseQTensorBase))
or not ctx.ln_out_needs_gather
):
ln_out.update_usage(rowwise_usage=False)
# Weight with column-wise usage is needed for dgrad GEMM.
@@ -496,8 +493,8 @@ class _LayerNormLinear(torch.autograd.Function):
if return_layernorm_output:
if return_layernorm_output_gathered:
shape = list(inp.shape)
shape[0] *= tp_size if with_input_all_gather else 1
return out, ln_out_return.view(shape)
return out, ln_out_return.view(inp_shape)
return out
@@ -631,7 +628,7 @@ class _LayerNormLinear(torch.autograd.Function):
ln_out_total_work = None
if ctx.ln_out_needs_gather:
quantizer = None
if ctx.input_quantizer is not None:
quantizer = ctx.input_quantizer
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
# If data is in FP8, we compute FP8 transposes manually
@@ -1376,6 +1373,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_current_scaling():
self._customize_quantizers_float8_current_scaling(fwd, recipe)
elif recipe.float8_block_scaling():
self._customize_quantizers_float8_blockwise_scaling(fwd, recipe)
# elif other recipes (mxfp8, etc)
def reset_layer_norm_parameters(self) -> None:
@@ -1677,3 +1676,14 @@ class LayerNormLinear(TransformerEngineBaseModule):
weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
weight_quantizer.internal = True
return [weight_quantizer]
def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on blockwise scaling recipe + layernorm_linear."""
assert (
recipe.float8_block_scaling()
), "Expected blockwise scaling recipe when customizing LayerNormLinear quantizers"
if fwd:
if self.sequence_parallel and self.parallel_mode == "column":
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].all_gather_usage = True
@@ -243,26 +243,18 @@ class _LayerNormMLP(torch.autograd.Function):
# All-gather is not supported with FP8 column-wise data
fc1_input_quantizer.set_usage(columnwise=False)
# for fp8 DelayedScaling: layernorm output = FP8
# only output of the linear is returned
# for return_layernorm_output: layernorm output = High precision, then cast to FP8
# high precision layernorm output and output of the linear are returned
# for debug: layernorm output = High precision to enable processing of this norm
with_quantized_norm = (
fp8
and not debug
and not return_layernorm_output
and not return_layernorm_output_gathered
)
if isinstance(fc1_input_quantizer, Float8BlockQuantizer):
# Kernels not available for norm fusion.
with_quantized_norm = False
# Apply normalization
ln_out, mu, rsigma = apply_normalization(
@@ -292,15 +284,16 @@ class _LayerNormMLP(torch.autograd.Function):
ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
ln_out_return = ln_out_total
if fp8 or debug:
ln_out = fc1_input_quantizer(ln_out)
fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
if isinstance(fc1_input_quantizer, Float8BlockQuantizer):
fc1_input_quantizer.all_gather_usage = False
ln_out_total = fc1_input_quantizer(ln_out_total)
else:
quantizer = None
if fp8 or debug:
quantizer = fc1_input_quantizer
if not with_quantized_norm:
ln_out = fc1_input_quantizer(ln_out)
fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
if ub_overlap_ag:
@@ -566,7 +559,6 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.tensor_objects = tensor_objects
ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
ctx.fc1_grad_input_quantizer = fc1_grad_input_quantizer
ctx.fc1_grad_weight_quantizer = fc1_grad_weight_quantizer
ctx.fc1_grad_output_quantizer = fc1_grad_output_quantizer
@@ -627,7 +619,7 @@ class _LayerNormMLP(torch.autograd.Function):
if return_layernorm_output:
if return_layernorm_output_gathered:
shape = list(inp_shape)
shape[0] *= tp_size if (sequence_parallel and set_parallel_mode) else 1
return fc2_out, ln_out_return.view(shape)
return fc2_out, ln_out_return.view(inp_shape)
return fc2_out
@@ -742,7 +734,7 @@ class _LayerNormMLP(torch.autograd.Function):
ub_obj_fc1_dgrad = None
if ctx.fc1_weight_requires_grad and ctx.tensor_parallel and ctx.sequence_parallel:
quantizer = None
if ctx.fp8 or ctx.debug:
quantizer = ctx.fc1_input_quantizer
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
# If data is in FP8, we compute FP8 transposes manually
@@ -1643,8 +1635,11 @@ class LayerNormMLP(TransformerEngineBaseModule):
super().set_meta_tensor(fwd, recipe)
# customize quantizers based on each recipe & layer configs
recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_current_scaling():
self._customize_quantizers_float8_current_scaling(fwd, recipe)
elif recipe.float8_block_scaling():
self._customize_quantizers_float8_blockwise_scaling(fwd, recipe)
# elif for other recipes (mxfp8, etc.)
def reset_layer_norm_parameters(self) -> None:
@@ -1996,6 +1991,22 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc2_weight_quantizer.internal = True
return [fc1_weight_quantizer, fc2_weight_quantizer]
def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on blockwise scaling recipe + layernorm_mlp."""
assert (
recipe.float8_block_scaling()
), "Expected blockwise scaling recipe when customizing LayerNormMLP quantizers"
if fwd:
if self.sequence_parallel and self.set_parallel_mode:
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].all_gather_usage = True
else:
if self.sequence_parallel and self.set_parallel_mode:
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT2
].all_gather_usage = True
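The effect of these hooks, reduced to a toy analog (the real setup goes through set_meta_tensor and the recipe; the names below are illustrative, not TE's API):

    class _DemoQuantizer:
        all_gather_usage = False

    def customize(quantizers, fwd, sequence_parallel, set_parallel_mode):
        # Mirror of the logic above: flag the gathered-tensor quantizer
        # so it emits COMPACT data that can be all-gathered in FP8.
        if sequence_parallel and set_parallel_mode:
            key = "scaling_fwd" if fwd else "scaling_bwd"
            quantizers[key].all_gather_usage = True

    qs = {"scaling_fwd": _DemoQuantizer(), "scaling_bwd": _DemoQuantizer()}
    customize(qs, fwd=True, sequence_parallel=True, set_parallel_mode=True)
    assert qs["scaling_fwd"].all_gather_usage
    assert not qs["scaling_bwd"].all_gather_usage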
def backward_dw(self):
"""
Execute the delayed weight gradient computation.
...