Unverified Commit 9985b02c authored by Zhongbo Zhu, committed by GitHub

[PyTorch] FP8 Subchannel Recipe With FP8 Gather And Configurable Scaling Factor Tensor Swizzling (#1707)

* functional kernel for columnwise + no-transpose option, still hacky
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* pass all quantizer unit tests
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* refactor, add gemm ready api
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* make format options private members, simplify api
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* swizzle scales right before gemm
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* bug fix of single layer test
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* attempt to fix lint issue
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fp8 gather pass, need minor refine
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix return_layernorm_output_gathered case
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* remove special cases, add sanity check before gemm
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* fix lint
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* lint ungrouped imports
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Implement dequantize for compact 1D blocks.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* add more unit test with dequantize compact supported
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* lint again
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* make ag for subchannel respect async
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* zero tolerance in distributed test
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* fix zero tolerance test
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* resolve rebase issues
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* lint & format
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* fix lint
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* clean up
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* bug fix
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* relax rtol for fp32 distributed test
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* fix some ci issue
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* fix ci test failure in debug mode
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* Force row-wise and column-wise data to have same data format

Prototype "all-gather usage" in quantizer.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove dead logic for high-precision AGs
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Debug FP8 block-wise tensor tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug distributed test
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Handle case where LayerNormLinear returns gathered norm output
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* fix debug mode
Signed-off-by: zhongboz <zhongboz@nvidia.com>

---------
Signed-off-by: zhongboz <zhongboz@nvidia.com>
Signed-off-by: Keith Wyss <kwyss@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Keith Wyss <kwyss@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 6123d7e0
......@@ -100,11 +100,15 @@ def main(argv=None, namespace=None):
# Quantization scheme
QUANTIZATION = args.quantization
if QUANTIZATION in ("fp8", "mxfp8", "fp8_block_scaling"):
global SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE
if QUANTIZATION in ("fp8", "mxfp8"):
SEQ_LEN = 32
BATCH_SIZE = 32
HIDDEN_SIZE = 128
elif QUANTIZATION == "fp8_block_scaling":
SEQ_LEN = 128
BATCH_SIZE = 128
HIDDEN_SIZE = 512
test_dict = [
test_quantizer,
......@@ -185,7 +189,7 @@ def _get_tolerances(dtype):
if dtype == torch.bfloat16:
return {"rtol": 1.6e-2, "atol": 1e-5}
if dtype == torch.float32:
return {"rtol": 1e-4, "atol": 1e-4}
return {"rtol": 1.2e-4, "atol": 1e-4}
raise ValueError(f"Unsupported dtype ({dtype})")
......@@ -649,7 +653,7 @@ def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs
if "return_layernorm_output" in kwargs:
output_single_node, norm_s = output_single_node
output_distributed, norm_d = output_distributed
if sequence_parallel:
if sequence_parallel and not kwargs.get("return_layernorm_output_gathered", False):
norm_d = _gather(norm_d)
_check_outputs(norm_s, norm_d)
......@@ -758,7 +762,7 @@ def _test_layernorm_mlp(set_parallel_mode=None, sequence_parallel=False, **kwarg
if "return_layernorm_output" in kwargs:
output_single_node, norm_s = output_single_node
output_distributed, norm_d = output_distributed
if sequence_parallel:
if sequence_parallel and not kwargs.get("return_layernorm_output_gathered", False):
norm_d = _gather(norm_d)
_check_outputs(norm_s, norm_d)
......
......@@ -260,6 +260,7 @@ class BlockwiseQuantizerReference:
eps: float = 0.0,
pow_2_scales: bool = False,
quant_tile_shape: Tuple[int, int] = (128, 128),
munge_scale_shapes: bool = True,
) -> QuantizeResult:
# sanity checks
assert x.dim() == 2
......@@ -277,27 +278,33 @@ class BlockwiseQuantizerReference:
assert quant_tile_shape in ((1, 128), (128, 128))
if quant_tile_shape[0] == 1:
# Quantize row-wise
return self.scale_munger.munge_scale_shapes_for_backend(
self._quantize_vector_tiling(
result = self._quantize_vector_tiling(
x,
quant_dtype,
tile_len=quant_tile_shape[1],
return_transpose=return_transpose,
pow_2_scales=pow_2_scales,
eps=eps,
),
)
if munge_scale_shapes:
result = self.scale_munger.munge_scale_shapes_for_backend(
result,
quant_tile_shape,
)
return result
else:
# Quantize block-wise
return self.scale_munger.munge_scale_shapes_for_backend(
self._quantize_square_block_tiling(
result = self._quantize_square_block_tiling(
x,
quant_dtype,
tile_len=quant_tile_shape[0],
return_transpose=return_transpose,
pow_2_scales=pow_2_scales,
eps=eps,
),
)
if munge_scale_shapes:
result = self.scale_munger.munge_scale_shapes_for_backend(
result,
quant_tile_shape,
)
return result
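As a usage note, here is a minimal sketch of how the new munge_scale_shapes flag can be exercised, mirroring how the test below calls it (the input sizes are illustrative, and the natural scale layout is presumably one scale per 1x128 tile):

import torch

# Quantize row-wise (1x128 tiles) but keep the scales in their natural,
# un-munged layout, e.g. to compare against a kernel emitting the COMPACT format.
ref = BlockwiseQuantizerReference()
x = torch.randn(256, 1024, device="cuda")
result = ref.quantize(
    x,
    quant_dtype=torch.float8_e4m3fn,
    return_transpose=True,
    quant_tile_shape=(1, 128),
    munge_scale_shapes=False,  # skip backend-specific padding/transposition
)
# result.scale presumably keeps shape (256, 8): one scale per 1x128 tile.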
......@@ -88,6 +88,126 @@ def initialize_for_many_scales(
return result
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
"M, N",
[
# full tile cases
(128, 128),
(256, 256),
(256, 1024),
(1024, 256),
# Padding required cases
(256, 272),
(303, 300),
(305, 256),
# Some larger tiles.
(2000, 2000),
(2048, 2000),
(2000, 1024),
(2048, 1024),
],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("eps", [0], ids=["eps_0"])
@pytest.mark.parametrize("pow_2_scales", [True], ids=["pow2scales"])
def test_quantization_1D_block_tiling_with_compact_data_and_scales(
x_dtype: torch.dtype,
M: int,
N: int,
quant_dtype: torch.dtype,
eps: float,
pow_2_scales: bool,
) -> None:
te_dtype = TE_DType[quant_dtype]
tile_size = (1, 128)
# This test compares the reference class against the class that uses
# CUDA kernels to quantize. They should quantize identically for elements
# that are not don't-care (padding) values in the scale factor shape.
ref_quantizer = BlockwiseQuantizerReference()
sut_quantizer = Float8BlockQuantizer(
fp8_dtype=te_dtype,
rowwise=True,
columnwise=True,
amax_epsilon=eps,
force_pow_2_scales=pow_2_scales,
block_scaling_dim=1,
all_gather_usage=True,
)
# Setup device and random seed
device = "cuda"
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Input
x = initialize_for_many_scales((M, N), tile_size, dtype=x_dtype, device=device)
x_fp8_sut = sut_quantizer.make_empty((M, N), dtype=x_dtype, device=device, requires_grad=False)
x_fp8_sut = sut_quantizer.update_quantized(x, x_fp8_sut)
x_fp8_sut_cpp_alloc = sut_quantizer(x)
assert x_fp8_sut._rowwise_data is not None
qx: torch.Tensor = x_fp8_sut._rowwise_data.view(dtype=quant_dtype)
assert x_fp8_sut._rowwise_scale_inv is not None
sx: torch.Tensor = x_fp8_sut._rowwise_scale_inv
qx_t = x_fp8_sut._columnwise_data
sx_t = x_fp8_sut._columnwise_scale_inv
qresult_ref = ref_quantizer.quantize(
x,
quant_dtype=quant_dtype,
return_transpose=True,
eps=eps,
pow_2_scales=pow_2_scales,
quant_tile_shape=tile_size,
munge_scale_shapes=False,
)
qx_ref, sx_ref, qx_t_ref, sx_t_ref = (
qresult_ref.data,
qresult_ref.scale,
qresult_ref.data_t,
qresult_ref.scale_t,
)
# match the reference quantize transpose output with the columnwise non-transpose method
qx_t_ref = qx_t_ref.transpose(-1, -2).contiguous()
sx_t_ref = sx_t_ref.transpose(-1, -2).contiguous()
# Check
torch.testing.assert_close(qx.float(), qx_ref.float(), atol=0.0, rtol=0.0)
torch.testing.assert_close(sx, sx_ref, atol=0.0, rtol=0.0)
assert qx_t is not None
qx_t = qx_t.view(dtype=quant_dtype)
assert qx_t_ref is not None
assert sx_t is not None
assert sx_t_ref is not None
torch.testing.assert_close(qx_t.float(), qx_t_ref.float(), atol=0.0, rtol=0.0)
torch.testing.assert_close(sx_t, sx_t_ref, atol=0.0, rtol=0.0)
# check that the C++ and Python allocators are equivalent
torch.testing.assert_close(
x_fp8_sut._rowwise_data, x_fp8_sut_cpp_alloc._rowwise_data, atol=0.0, rtol=0.0
)
torch.testing.assert_close(
x_fp8_sut._rowwise_scale_inv, x_fp8_sut_cpp_alloc._rowwise_scale_inv, atol=0.0, rtol=0.0
)
torch.testing.assert_close(
x_fp8_sut._columnwise_data, x_fp8_sut_cpp_alloc._columnwise_data, atol=0.0, rtol=0.0
)
torch.testing.assert_close(
x_fp8_sut._columnwise_scale_inv,
x_fp8_sut_cpp_alloc._columnwise_scale_inv,
atol=0.0,
rtol=0.0,
)
# check that the C++ and Python outputs report the same data format
assert x_fp8_sut._data_format == x_fp8_sut_cpp_alloc._data_format
def check_quantization_block_tiling_versus_reference(
x_dtype: torch.dtype,
M: int,
......
......@@ -385,7 +385,7 @@ class TestFP8RecipeLinearBase:
)
# recipe1
using_fp8_recipe = recipe1 != GetRecipes.none
using_fp8_recipe = recipe1() != GetRecipes.none()
if using_fp8_recipe:
with fp8_autocast(enabled=True, fp8_recipe=recipe1()):
y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_linear(x, w, bias, gradient)
......@@ -393,7 +393,7 @@ class TestFP8RecipeLinearBase:
y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_linear(x, w, bias, gradient)
# recipe2
using_fp8_recipe = recipe2 != GetRecipes.none
using_fp8_recipe = recipe2() != GetRecipes.none()
if using_fp8_recipe:
with fp8_autocast(enabled=True, fp8_recipe=recipe2()):
y_q, dgrad, wgrad, bgrad = self.run_linear(x, w, bias, gradient)
......@@ -608,7 +608,7 @@ class TestFP8RecipeLayerNormLinearBase(TestFP8RecipeLinearBase):
)
# recipe1
using_fp8_recipe = recipe1 != GetRecipes.none
using_fp8_recipe = recipe1() != GetRecipes.none()
if using_fp8_recipe:
with fp8_autocast(enabled=True, fp8_recipe=recipe1()):
y_q_ref, ln_out_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_layernorm_linear(
......@@ -630,7 +630,7 @@ class TestFP8RecipeLayerNormLinearBase(TestFP8RecipeLinearBase):
)
# recipe2
using_fp8_recipe = recipe2 != GetRecipes.none
using_fp8_recipe = recipe2() != GetRecipes.none()
if using_fp8_recipe:
with fp8_autocast(enabled=True, fp8_recipe=recipe2()):
y_q, ln_out, dgrad, wgrad, bgrad = self.run_layernorm_linear(
......
......@@ -176,7 +176,40 @@ class TestFloat8BlockwiseTensor:
)
@pytest.mark.parametrize("block_scaling_dim", [1, 2])
@pytest.mark.parametrize("dq_columnwise", [True, False])
@pytest.mark.parametrize("all_gather_usage", [True, False])
def test_quantize_dequantize_dims(
self,
dims: DimsType,
block_scaling_dim: int,
dq_columnwise: bool,
all_gather_usage: bool,
) -> None:
if all_gather_usage and block_scaling_dim != 1:
pytest.skip("all_gather_usage only implemented for 1D block quantization.")
atol = _tols[tex.DType.kFloat8E4M3]["atol"]
rtol = _tols[tex.DType.kFloat8E4M3]["rtol"]
quantizer = Float8BlockQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
rowwise=True,
columnwise=dq_columnwise,
block_scaling_dim=block_scaling_dim,
all_gather_usage=all_gather_usage,
)
self._test_quantize_dequantize(
quantizer=quantizer,
dims=dims,
atol=atol,
rtol=rtol,
dequant_columnwise=dq_columnwise,
)
@pytest.mark.parametrize(
"dims", [[], 256, 311, [264], [256, 512], [250, 500], [7, 5, 3], [2, 3, 5, 3]]
)
@pytest.mark.parametrize("block_scaling_dim", [1, 2])
@pytest.mark.parametrize("dq_columnwise", [True, False])
@pytest.mark.xfail(raises=NotImplementedError)
def test_quantize_dequantize_compact_format(
self, dims: DimsType, block_scaling_dim: int, dq_columnwise: bool
) -> None:
atol = _tols[tex.DType.kFloat8E4M3]["atol"]
......@@ -186,6 +219,7 @@ class TestFloat8BlockwiseTensor:
rowwise=True,
columnwise=dq_columnwise,
block_scaling_dim=block_scaling_dim,
all_gather_usage=True,
)
self._test_quantize_dequantize(
quantizer=quantizer,
......@@ -250,8 +284,13 @@ class TestFloat8BlockwiseTensor:
@pytest.mark.parametrize("dims", [[256, 512], [250, 500]])
@pytest.mark.parametrize("block_scaling_dim", [1, 2])
def test_serialization(self, dims: DimsType, block_scaling_dim: int) -> None:
@pytest.mark.parametrize("all_gather_usage", [True, False])
def test_serialization(
self, dims: DimsType, block_scaling_dim: int, all_gather_usage: bool
) -> None:
"""Test serialization of Float8BlockwiseQTensor"""
if all_gather_usage and block_scaling_dim != 1:
pytest.skip("all_gather_usage only implemented for 1D block quantization.")
device = "cuda"
dtype = torch.bfloat16
x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device)
......@@ -260,6 +299,7 @@ class TestFloat8BlockwiseTensor:
rowwise=True,
columnwise=True,
block_scaling_dim=block_scaling_dim,
all_gather_usage=all_gather_usage,
)
# Create FP8 tensor
......@@ -283,6 +323,7 @@ class TestFloat8BlockwiseTensor:
assert x_fp8_loaded._is_2D_scaled == x_fp8._is_2D_scaled
assert x_fp8_loaded.dtype == x_fp8.dtype
assert x_fp8_loaded._fp8_dtype == x_fp8._fp8_dtype
assert x_fp8_loaded._data_format == x_fp8._data_format
# Test that dequantized values match
x_fp8_dequant = x_fp8.dequantize()
......
......@@ -252,11 +252,14 @@ struct QuantizationConfig {
bool force_pow_2_scales = false;
float amax_epsilon = 0.0f;
NVTETensor noop_tensor = nullptr;
Float8BlockScaleTensorFormat float8_block_scale_tensor_format =
Float8BlockScaleTensorFormat::GEMM_READY;
static constexpr size_t attr_sizes[] = {
sizeof(bool), // force_pow_2_scales
sizeof(float), // amax_epsilon
sizeof(NVTETensor) // noop_tensor
sizeof(NVTETensor), // noop_tensor
sizeof(Float8BlockScaleTensorFormat) // float8_block_scale_tensor_format
};
};
......
......@@ -302,6 +302,13 @@ enum NVTEQuantizationConfigAttribute {
conditional early even when captured in a static CUDA graph.
*/
kNVTEQuantizationConfigNoopTensor = 2,
/*! Data format for an FP8 block-scaled tensor
*
* This is not the right design since the tensor format is a
* property of the tensor, not the quantization. This enum will
* likely be refactored away in the future.
*/
kNVTEQuantizationConfigFloat8BlockScaleTensorFormat = 3,
kNVTEQuantizationConfigNumAttributes
};
......@@ -721,6 +728,16 @@ class TensorWrapper {
NVTETensor tensor_ = nullptr;
};
/*! \enum Float8BlockScaleTensorFormat
* \brief Data format for an FP8 block-scaled tensor
*/
enum class Float8BlockScaleTensorFormat {
/*! FP8 data is transposed if needed and scales are swizzled */
GEMM_READY = 0,
/*! FP8 data is untransposed and scales are not swizzled or padded */
COMPACT = 1
};
/*! \struct QuantizationConfigWrapper
* \brief C++ wrapper for NVTEQuantizationConfigWrapper.
*/
......@@ -774,6 +791,13 @@ class QuantizationConfigWrapper {
sizeof(NVTETensor));
}
/*! \brief Set FP8 block-scaled tensor format */
void set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat format) {
nvte_set_quantization_config_attribute(config_,
kNVTEQuantizationConfigFloat8BlockScaleTensorFormat,
&format, sizeof(Float8BlockScaleTensorFormat));
}
private:
/*! \brief Wrapped NVTEQuantizationConfig. */
NVTEQuantizationConfig config_ = nullptr;
......
......@@ -562,6 +562,9 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config,
case kNVTEQuantizationConfigNoopTensor:
std::memcpy(buf, &config_.noop_tensor, attr_size);
break;
case kNVTEQuantizationConfigFloat8BlockScaleTensorFormat:
std::memcpy(buf, &config_.float8_block_scale_tensor_format, attr_size);
break;
default:
NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
}
......@@ -594,6 +597,9 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
case kNVTEQuantizationConfigNoopTensor:
std::memcpy(&config_.noop_tensor, buf, attr_size);
break;
case kNVTEQuantizationConfigFloat8BlockScaleTensorFormat:
std::memcpy(&config_.float8_block_scale_tensor_format, buf, attr_size);
break;
default:
NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
}
......
......@@ -31,25 +31,27 @@ void quantize_transpose_square_blockwise(const SimpleTensor &input, SimpleTensor
// enum class for rowwise usage
enum class FP8BlockwiseRowwiseOption {
// No rowwise data
// No rowwise data, skip rowwise quantization
NONE,
// Rowwise data, scales in GEMM format
ROWWISE
// TODO: FP8 all gather requires some changes.
// 1. Compact scales are better for gathering than the GEMM format.
ROWWISE_GEMM_READY,
// Rowwise data, scales in compact format, needs extra processing (padding, transposing) before GEMM
ROWWISE_COMPACT
};
// enum class for columnwise usage
// On Hopper (sm90), which supports only TN FP8 GEMM, the columnwise data must be transposed when doing 1D block scaling
enum class FP8BlockwiseColumnwiseOption {
// No columnwise data
// No columnwise data, skip columnwise quantization
NONE,
// Columnwise data transposed from original shape.
// Scales in GEMM format corresponding to GEMM ingesting transposed column data.
COLUMNWISE_TRANSPOSE
// TODO: FP8 all gather requires some changes.
// 1. The transpose gets in the way of the all gather.
// 2. Compact scales are better for gathering than the GEMM format.
// On Hopper (sm90), GEMM_READY means that columnwise quantization also fuses the transpose op.
// On newer SM versions with TN, NT, and NN FP8 GEMM support, GEMM_READY does not fuse the transpose.
COLUMNWISE_GEMM_READY,
// Columnwise data in original shape
// Scales in compact format, needs extra processing (padding, transposing) before GEMM
COLUMNWISE_COMPACT
};
void quantize_transpose_vector_blockwise(const SimpleTensor &input, SimpleTensor &scale_inv,
......
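To make the two rowwise scale layouts concrete, here is a hedged PyTorch sketch (illustrative only, not the library's code) of what ROWWISE_GEMM_READY vs ROWWISE_COMPACT imply for the scale-inv tensor of 1x128 tiles; the 448 divisor (e4m3 max) and helper name are assumptions:

import torch
import torch.nn.functional as F

def rowwise_scale_inv(x: torch.Tensor, compact: bool) -> torch.Tensor:
    # One scale-inv per 1x128 tile of a 2D (M, K) tensor; assumes K % 128 == 0.
    M, K = x.shape
    amax = x.view(M, K // 128, 128).abs().amax(dim=-1)  # natural layout (M, K/128)
    scale_inv = amax / 448.0  # 448 = e4m3 max; illustrative scale choice
    if compact:
        return scale_inv  # COMPACT: natural layout, no padding
    # GEMM_READY: transposed to (K/128, M) with M padded up to a multiple of 4
    s = scale_inv.t().contiguous()
    return F.pad(s, (0, (-M) % 4))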
......@@ -99,14 +99,14 @@ Step 2: Cast and store to output_c
| ... |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
Step 3: Transpose, cast and store to output_t
Step 3 (if columnwise transpose is True, GEMM_READY): Transpose, cast and store to output_t
* shared memory: 128x128 elements with type=InputType (the graph below doesn't consider padding)
* 8 warps
* Loop 2 times
* What each thread does in each loop:
* 2 elements (in a row) are read from the shared memory at a time, for a total of 16 times
* Every 8 consecutive threads do reduction and calculate the amax of each column
* 16 elements are quantized and write to output_c at a time, for a total of 2 times
* 16 elements are quantized and write to output_t at a time, for a total of 2 times
+------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+
| T0 | T8 | T16 | T24 | | | | T0 | T8 | T16 | T24 | | | |
| T1 | T9 | T17 | T25 | | | | T1 | T9 | T17 | T25 | | | |
......@@ -118,6 +118,29 @@ Step 3: Transpose, cast and store to output_t
| T7 | T15 | T23 | T31 | | | | T7 | T15 | T23 | T31 | | | |
+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+
Step 3 (if columnwise transpose is False, COMPACT format): skip transpose, cast and store to output_t
* shared memory: 128x128 elements with type=InputType (the graph below doesn't consider padding)
* 8 warps
* Loop 1 time
* What each thread does in each loop:
* 16 elements (in a row) are read from the shared memory, for a total of 4 rows;
* it takes 8 smem reads to get 16 elements in a row, so the thread tile shape is 4x16 (4 rows x 16 columns)
* Every 32 consecutive threads in a warp do a reduction to calculate the amax of each column,
* so each thread performs the warp shuffle reduction 16 times, once per column
* 16 elements are quantized and written to output_t at a time, for a total of 4 times
+------16 elements-------+------16 elements-------+-----80 elements-----+------16 elements------+
| T0 | | | |
| T1 | | | |
| T2 | | | |
| T3 | | | |
| T4 | | | |
| T5 | | | |
| T6 | | | |
| T7 | | | |
| ... | | | |
| T31 | | | |
+-----------------------+-----------------------+-----------------------+-----------------------+
*/
// clang-format on
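For reference, the compact (no-transpose) columnwise path documented above corresponds to plain 128x1 block quantization. A PyTorch sketch under the stated assumptions (M divisible by 128, e4m3 output, illustrative clamping), not the kernel itself:

import torch

def columnwise_compact_quantize(x: torch.Tensor):
    # 128x1 block quantization: amax over each 128-row block, per column.
    # The data stays untransposed (same shape as the input).
    M, K = x.shape
    blocks = x.view(M // 128, 128, K)
    amax = blocks.abs().amax(dim=1)                    # (M/128, K)
    scale = 448.0 / amax.clamp(min=1e-12)              # 448 = e4m3 max
    q = (blocks * scale.unsqueeze(1)).to(torch.float8_e4m3fn)
    return q.view(M, K), scale.reciprocal()            # data, scale_inv (M/128, K)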
......@@ -140,6 +163,7 @@ constexpr int kNumThreadsLoad = kTileDim / kNVecIn;
constexpr int kNumThreadsStore = kTileDim / kNVecOut;
static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kThreadsPerWarp");
static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp");
constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp;
template <bool kAligned, typename CType, typename IType, typename OType>
__global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel(
......@@ -149,9 +173,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon,
FP8BlockwiseRowwiseOption rowwise_option, FP8BlockwiseColumnwiseOption columnwise_option,
const bool pow_2_scaling) {
bool return_rowwise = rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE;
bool return_columnwise_transpose =
columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE;
bool return_rowwise = rowwise_option != FP8BlockwiseRowwiseOption::NONE;
bool return_columnwise_gemm_ready =
columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
bool return_columnwise_compact =
columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT;
using SMemVec = Vec<IType, kNVecSMem>;
using OVec = Vec<OType, kNVecOut>;
......@@ -299,8 +325,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
}
}
// Step 3: Transpose, cast and store to output_t
if (return_columnwise_transpose) {
// Step 3 (return_columnwise_gemm_ready): Transpose, cast and store to output_t
if (return_columnwise_gemm_ready) {
constexpr int c_stride =
kThreadsPerBlock / kNumThreadsStore; // Stride in columns of shared memory
constexpr int num_iterations = kTileDim / (c_stride * kNVecSMem);
......@@ -385,6 +411,103 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
}
}
}
// Step 3 (return_columnwise_compact): cast in 128x1 style and store to output_t, skipping the transpose
if (return_columnwise_compact) {
// thread tile is 4x16 (4 rows x 16 columns); covering 16 columns takes 8 smem reads
constexpr int kThreadTileRow = kTileDim / kThreadsPerWarp;
constexpr int kThreadTileCol = kNVecOut;
using RegVec = Vec<IType, kThreadTileCol>;
using RegScaleVec = Vec<CType, kThreadTileCol>;
constexpr int num_smem_reads = kNVecOut / kNVecSMem;
// c_stride will not be used here because we only have one iteration
// constexpr int c_stride = kThreadTileCol * kNumWarps / kNVecSMem;
constexpr int num_iterations =
kTileDim / (kNumWarps * kThreadTileCol); // should be only one iteration
static_assert(num_iterations == 1,
"num_iterations should be 1 for columnwise non-transpose case");
const int thr_idx_in_warp = threadIdx.x % kThreadsPerWarp;
const int warp_idx = threadIdx.x / kThreadsPerWarp;
const int r_s = thr_idx_in_warp * kThreadTileRow; // Row in shared memory
int c_s = warp_idx * num_smem_reads; // Column in shared memory
size_t r_g = static_cast<size_t>(blockIdx.y) * kTileDim + r_s; // Row in global memory
const size_t c_g =
static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem; // Column in global memory
const size_t num_ele = c_g < row_length
? min(static_cast<size_t>(kThreadTileCol), row_length - c_g)
: 0; // For not aligned case
#pragma unroll
for (int iter = 0; iter < num_iterations; ++iter) {
RegVec reg_vec[kThreadTileRow];
RegScaleVec thr_scale;
// Step 3.1: Load from shared memory to registers
#pragma unroll
for (int i = 0; i < kThreadTileRow; ++i) {
int r = r_s + i;
#pragma unroll
for (int j = 0; j < num_smem_reads; ++j) {
int c = c_s + j;
SMemVec smem_vec = smem[r * kSMemCol + c];
// copy smem_vec to reg vec with its elements
#pragma unroll
for (int k = 0; k < kNVecSMem; ++k) {
reg_vec[i].data.elt[j * kNVecSMem + k] = smem_vec.data.elt[k];
}
}
}
#pragma unroll
for (int reg_idx = 0; reg_idx < kThreadTileCol; ++reg_idx) {
// Step 3.2: Compute local amax
CType amax = 0;
#pragma unroll
for (int i = 0; i < kThreadTileRow; ++i) {
amax = fmaxf(amax, fabsf(reg_vec[i].data.elt[reg_idx]));
}
// Step 3.3: Reduce amax
const bool is_src_lane = thr_idx_in_warp == 0;
amax = warp_reduce_max<kThreadsPerWarp>(amax);
constexpr int lane_zero = 0;
amax = __shfl_sync(0xFFFFFFFF, amax, lane_zero);
// Step 3.4: Compute scale
CType scale;
scale = compute_scale_from_types<IType, OType>(amax, epsilon, pow_2_scaling);
thr_scale.data.elt[reg_idx] = scale;
// Step 3.5: Write scale_inv_t
bool write_scale_inv = is_src_lane;
if constexpr (!kAligned) {
write_scale_inv &= (c_g + reg_idx < row_length);
}
if (write_scale_inv) {
CType scale_inv = 1.0 / scale;
size_t row_idx = static_cast<size_t>(blockIdx.y);
size_t col_idx = static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem + reg_idx;
tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv;
}
}
// Step 3.6: Quantize
for (int row_idx = 0; row_idx < kThreadTileRow; ++row_idx) {
OType* output_g =
&output_t[(r_g + row_idx) * row_length + c_g]; // Output address in global memory
OVec output_vec;
#pragma unroll
for (int i = 0; i < kThreadTileCol; ++i) {
output_vec.data.elt[i] = static_cast<OType>(
static_cast<CType>(reg_vec[row_idx].data.elt[i]) * thr_scale.data.elt[i]);
}
// Step 3.7: Store output_t
if constexpr (kAligned) {
output_vec.store_to(output_g);
} else {
if (r_g + row_idx < num_rows) {
output_vec.store_to_elts(output_g, 0, num_ele);
}
}
}
// Step 3.8: Update output address, column index of shared memory
// this section shouldn't matter since we only have one iteration
}
}
}
} // namespace
......@@ -400,11 +523,6 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
const bool pow2_scale, cudaStream_t stream) {
NVTE_API_CALL(quantize_transpose_vector_blockwise);
// assert that rowwise_option and columnwise_option are not both NONE
NVTE_CHECK(rowwise_option != FP8BlockwiseRowwiseOption::NONE ||
columnwise_option != FP8BlockwiseColumnwiseOption::NONE,
"rowwise_option and columnwise_option cannot both be NONE");
const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u;
size_t num_elements = row_length;
size_t num_rows = 1;
......@@ -425,32 +543,43 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
size_t scale_t_stride_y = 0;
if (rowwise_option != FP8BlockwiseRowwiseOption::NONE) {
NVTE_CHECK(rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE,
NVTE_CHECK(rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY ||
rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE_COMPACT,
"Unexpected rowwise enum value");
NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape.");
NVTE_CHECK(scale_inv.shape.size() == 2, "Scale dimension must be 2.");
size_t scale_k = scale_inv.shape[1];
scale_stride_x = scale_k;
scale_stride_y = 1;
bool rowwise_compact = rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE_COMPACT;
scale_stride_x = rowwise_compact ? 1 : scale_k;
scale_stride_y = rowwise_compact ? scale_k : 1;
}
if (columnwise_option != FP8BlockwiseColumnwiseOption::NONE) {
NVTE_CHECK(columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE,
"Unexpected columnwise enum value");
NVTE_CHECK(output_t.shape.size() == input.shape.size(),
"output_t must have same number of dimensions as input.");
if (output_t.shape.size() > 0) {
if (columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY) {
NVTE_CHECK(output_t.shape[0] == row_length, "Wrong dimension 0 of output_t.");
for (size_t i = 1; i < output_t.shape.size(); ++i) {
NVTE_CHECK(output_t.shape.at(i) == input.shape.at(i - 1), "Wrong dimension in output_t");
}
} else {
NVTE_CHECK(columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT,
"Unexpected columnwise option enum value");
NVTE_CHECK(output_t.shape[0] == input.shape[0], "Wrong dimension 0 of output_t.");
NVTE_CHECK(
input.shape == output_t.shape,
"Input and output_t must have the same shape for columnwise non-transpose case.");
}
}
NVTE_CHECK(output.dtype == output_t.dtype, "output and output_t need to have the same dtype.");
NVTE_CHECK(scale_inv_t.shape.size() == 2, "Scale_t dimension must be 2.");
scale_t_stride_x = scale_inv_t.shape[1];
scale_t_stride_y = 1;
bool columnwise_compact = columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT;
size_t scale_t_k = scale_inv_t.shape[1];
scale_t_stride_x = columnwise_compact ? 1 : scale_t_k;
scale_t_stride_y = columnwise_compact ? scale_t_k : 1;
}
const size_t num_blocks_x = DIVUP(row_length, (size_t)kTileDim);
......
......@@ -1283,12 +1283,25 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
"IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D");
bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : false;
float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f;
FP8BlockwiseRowwiseOption rowwise_option = output_tensor->has_data()
? FP8BlockwiseRowwiseOption::ROWWISE
: FP8BlockwiseRowwiseOption::NONE;
FP8BlockwiseColumnwiseOption columnwise_option =
output_tensor->has_columnwise_data() ? FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE
: FP8BlockwiseColumnwiseOption::NONE;
FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE;
FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE;
if (output_tensor->has_data()) {
bool rowwise_compact = quant_config_cpp
? quant_config_cpp->float8_block_scale_tensor_format ==
Float8BlockScaleTensorFormat::COMPACT
: false;
rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT
: FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
}
if (output_tensor->has_columnwise_data()) {
bool columnwise_compact = quant_config_cpp
? quant_config_cpp->float8_block_scale_tensor_format ==
Float8BlockScaleTensorFormat::COMPACT
: false;
columnwise_option = columnwise_compact
? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT
: FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
}
quantize_transpose_vector_blockwise(input_tensor->data, output_tensor->scale_inv,
output_tensor->columnwise_scale_inv, output_tensor->data,
output_tensor->columnwise_data, epsilon, rowwise_option,
......
......@@ -75,6 +75,10 @@
.value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \
.value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) \
.value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); \
pybind11::enum_<transformer_engine::Float8BlockScaleTensorFormat>( \
m, "Float8BlockScaleTensorFormat", pybind11::module_local()) \
.value("GEMM_READY", transformer_engine::Float8BlockScaleTensorFormat::GEMM_READY) \
.value("COMPACT", transformer_engine::Float8BlockScaleTensorFormat::COMPACT); \
pybind11::enum_<transformer_engine::CommOverlapType>(m, "CommOverlapType", \
pybind11::module_local()) \
.value("RS", transformer_engine::CommOverlapType::RS) \
......
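Once bound, the format enum is reachable from Python. A small hedged check (module alias tex as used elsewhere in this PR):

import transformer_engine_torch as tex

fmt = tex.Float8BlockScaleTensorFormat.COMPACT
assert fmt != tex.Float8BlockScaleTensorFormat.GEMM_READY
# Quantizers with all_gather_usage=True produce tensors in COMPACT format;
# they are converted to GEMM_READY before being fed to a GEMM.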
......@@ -77,6 +77,14 @@ def general_gemm(
# There is no use_split_accumulator == False
# implementation for Float8BlockwiseQTensorBase GEMM
use_split_accumulator = True
# Check that data format is supported
if (
A._data_format != tex.Float8BlockScaleTensorFormat.GEMM_READY
or B._data_format != tex.Float8BlockScaleTensorFormat.GEMM_READY
):
raise RuntimeError("GEMM with Float8BlockwiseQTensor requires GEMM_READY format")
args = (
A,
transa, # transa
......
......@@ -173,6 +173,8 @@ class Float8BlockQuantizer : public Quantizer {
bool force_pow_2_scales = false;
// Amax within quantization tile has a floor of epsilon.
float amax_epsilon = 0.0;
// Whether quantized tensor will be used in an all-gather
bool all_gather_usage = false;
private:
int block_scaling_dim = 2;
......
......@@ -81,6 +81,9 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob
auto my_quantizer_bw = static_cast<Float8BlockQuantizer*>(my_quantizer.get());
quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
if (my_quantizer_bw->all_gather_usage) {
quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT);
}
}
NVTE_SCOPED_GIL_RELEASE({
nvte_quantize_v2(te_input.data(), te_output.data(), quant_config,
......
......@@ -170,6 +170,9 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
auto my_quantizer_bw = static_cast<Float8BlockQuantizer *>(my_quantizer.get());
quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
if (my_quantizer_bw->all_gather_usage) {
quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT);
}
}
NVTE_SCOPED_GIL_RELEASE({
nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config,
......@@ -328,6 +331,9 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
auto my_quantizer_bw = static_cast<Float8BlockQuantizer *>(my_quantizer.get());
quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
if (my_quantizer_bw->all_gather_usage) {
quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT);
}
}
NVTE_SCOPED_GIL_RELEASE({
nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config,
......
......@@ -261,6 +261,7 @@ Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quanti
this->amax_epsilon = quantizer.attr("amax_epsilon").cast<float>();
NVTE_CHECK(this->block_scaling_dim == 1 || this->block_scaling_dim == 2,
"Unsupported block scaling dim.");
this->all_gather_usage = quantizer.attr("all_gather_usage").cast<bool>();
}
void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const {
......@@ -299,6 +300,10 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
size_t m_dim = numel / k_dim;
constexpr size_t kBlockLen = 128;
Float8BlockScaleTensorFormat data_format =
(all_gather_usage ? Float8BlockScaleTensorFormat::COMPACT
: Float8BlockScaleTensorFormat::GEMM_READY);
if (rowwise_usage) {
if (rowwise_data.has_value()) {
data_rowwise = std::move(*rowwise_data);
......@@ -308,14 +313,24 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
size_t sinv0 = 0;
size_t sinv1 = 0;
if (block_scaling_dim == 2) {
// 2D scaling is always GEMM_READY for now
NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY,
"2D scaling is always GEMM_READY for now.");
sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
sinv1 = roundup((k_dim + kBlockLen - 1) / kBlockLen, 4);
} else if (block_scaling_dim == 1) {
// 1D scaling can be GEMM_READY or COMPACT
bool rowwise_compact = data_format == Float8BlockScaleTensorFormat::COMPACT;
// the default rowwise scaling factor shape already transposes the scaling factor, so it is GEMM_READY
sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
sinv1 = roundup(m_dim, 4);
sinv1 = rowwise_compact ? m_dim : roundup(m_dim, 4);
// if the rowwise format is compact, the scaling factor is not transposed
if (rowwise_compact) {
std::swap(sinv0, sinv1);
}
} else {
NVTE_CHECK(false,
"Unsupported block_scaling_dim in create_tensor rowwise."
NVTE_ERROR(
"Unsupported block_scaling_dim in create_tensor rowwise. "
"Expected 1 or 2. Got ",
block_scaling_dim);
}
......@@ -332,6 +347,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
NVTE_CHECK(torch_shape.size() == shape.size(), "Shape expected to match torch shape. Shape ",
columnwise_shape, " torch shape: ", torch_columnwise_shape);
if (torch_shape.size() > 0) {
if (!all_gather_usage) {
torch_columnwise_shape.reserve(torch_shape.size());
columnwise_shape.reserve(shape.size());
torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]);
......@@ -340,18 +356,32 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
torch_columnwise_shape.push_back(torch_shape[i]);
columnwise_shape.push_back(shape[i]);
}
} else {
// assert we are doing 1D scaling
NVTE_CHECK(block_scaling_dim == 1,
"Compact columnwise format is not supported for 128x128 2D block scaling.");
torch_columnwise_shape = torch_shape;
columnwise_shape = shape;
}
}
size_t sinv0 = 0;
size_t sinv1 = 0;
if (block_scaling_dim == 2) {
// 2D scaling is always GEMM_READY for now
NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY,
"2D scaling is always GEMM_READY for now.");
sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
sinv1 = roundup((m_dim + kBlockLen - 1) / kBlockLen, 4);
} else if (block_scaling_dim == 1) {
bool columnwise_compact = data_format == Float8BlockScaleTensorFormat::COMPACT;
sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
sinv1 = roundup(k_dim, 4);
sinv1 = columnwise_compact ? k_dim : roundup(k_dim, 4);
// GEMM_READY case: the scaling factor is [sinv0, sinv1], already transposed here for cuBLAS.
// In the COMPACT case, since we apply 128x1 scaling without transposing the columnwise data,
// the scaling factor is also [sinv0, sinv1], so there is no need to swap sinv0 and sinv1 here.
} else {
NVTE_CHECK(false,
"Unsupported block_scaling_dim in create_tensor columnwise."
NVTE_ERROR(
"Unsupported block_scaling_dim in create_tensor columnwise. "
"Expected 1 or 2. Got ",
block_scaling_dim);
}
......@@ -373,7 +403,7 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
"rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise,
"rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise,
"fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer,
"is_2D_scaled"_a = (block_scaling_dim == 2));
"is_2D_scaled"_a = (block_scaling_dim == 2), "data_format"_a = data_format);
} else {
py::handle Float8BlockwiseQTensorClass(
reinterpret_cast<PyObject*>(Float8BlockwiseQTensorPythonClass));
......@@ -381,7 +411,8 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
"shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), "rowwise_data"_a = data_rowwise,
"columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = scale_inv_rowwise,
"columnwise_scale_inv"_a = scale_inv_colwise, "fp8_dtype"_a = this->dtype,
"quantizer"_a = this->quantizer, "is_2D_scaled"_a = (block_scaling_dim == 2));
"quantizer"_a = this->quantizer, "is_2D_scaled"_a = (block_scaling_dim == 2),
"data_format"_a = data_format);
}
return {std::move(tensor), std::move(ret)};
......
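The scale-inv shape logic in the two create_tensor branches above can be condensed into a few lines of Python (a sketch mirroring the C++ with kBlockLen = 128; the function name is hypothetical, not a real API):

def scale_inv_shape(m: int, k: int, columnwise: bool, compact: bool, block: int = 128):
    # Mirrors create_tensor's 1D (block_scaling_dim == 1) branches; sketch only.
    ceil_div = lambda a, b: (a + b - 1) // b
    roundup = lambda a, b: ceil_div(a, b) * b
    if not columnwise:
        # rowwise 1x128 tiles: GEMM_READY transposes and pads, COMPACT does not
        return (m, ceil_div(k, block)) if compact else (ceil_div(k, block), roundup(m, 4))
    # columnwise 128x1 tiles: same [sinv0, sinv1] order in both formats
    return (ceil_div(m, block), k if compact else roundup(k, 4))

# e.g. a 256x1024 tensor: rowwise GEMM_READY -> (8, 256), rowwise COMPACT -> (256, 8)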
......@@ -8,6 +8,7 @@ from __future__ import annotations
from collections.abc import Iterable
from contextlib import contextmanager, AbstractContextManager, ContextDecorator
from functools import lru_cache
from dataclasses import dataclass
import math
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import warnings
......@@ -19,6 +20,15 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._common_utils import _get_module_fsdp_state
from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_modules
try:
import torch.distributed._symmetric_memory as symm_mem
HAS_TORCH_SYMMETRIC = True
except ImportError:
HAS_TORCH_SYMMETRIC = False
import transformer_engine_torch as tex
from . import torch_version
from .utils import (
is_non_tn_fp8_gemm_supported,
......@@ -34,14 +44,8 @@ from .tensor.quantized_tensor import QuantizedTensor, Quantizer
from .tensor._internal.float8_tensor_base import Float8TensorBase
from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from .tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor, DebugQuantizer
try:
import torch.distributed._symmetric_memory as symm_mem
HAS_TORCH_SYMMETRIC = True
except ImportError:
HAS_TORCH_SYMMETRIC = False
__all__ = ["checkpoint", "CudaRNGStatesTracker"]
......@@ -977,6 +981,67 @@ def _all_gather_fp8(
return out, handle
def _set_quantizer_format(quantizer: Quantizer, compact: bool = False) -> None:
"""Make quantizer compact"""
_quantizer = quantizer
if isinstance(quantizer, DebugQuantizer):
_quantizer = quantizer.parent_quantizer
if isinstance(_quantizer, Float8BlockQuantizer):
_quantizer.all_gather_usage = compact
def _post_process_fp8_blockwise_gather(
out: Float8BlockwiseQTensorBase,
quantizer: Float8BlockQuantizer,
handle: Optional[torch.distributed.Work] = None,
) -> Float8BlockwiseQTensorBase:
"""Post-process FP8 blockwise gather."""
if handle is not None:
handle.wait()
handle = None
if out._is_gemm_ready_format():
return out
needs_columnwise_data_transpose = (
quantizer is not None and quantizer.columnwise_usage and not is_non_tn_fp8_gemm_supported()
)
need_rowwise_scale_transpose = (
quantizer is not None and quantizer.rowwise_usage and not is_non_tn_fp8_gemm_supported()
)
# cuBLAS requires a transpose of the scale-inv tensor. Suppose the original input is 256x1024:
# columnwise compact format means doing 128x1 quantization of it,
# so the quantized tensor is 256x1024 and the scale inv is 2x1024.
# If we were using the GEMM_READY format, it would be equivalent to 1x128 quantization
# on a transposed 1024x256 tensor, so the scale inv would be 1024x2, and cuBLAS requires 2x1024.
# Therefore, it turns out we don't need to transpose the scale inv, only the columnwise data.
if needs_columnwise_data_transpose:
out._transpose_columnwise_data()
if need_rowwise_scale_transpose:
out._rowwise_scale_inv = out._rowwise_scale_inv.transpose(-2, -1).contiguous()
out._data_format = tex.Float8BlockScaleTensorFormat.GEMM_READY
return out
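Continuing the 256x1024 example from the comment above, a quick shape walk-through (illustrative numbers only):

# Columnwise COMPACT on a 256x1024 input: 128x1 tiles -> scale_inv is (2, 1024).
# The GEMM_READY route (1x128 tiles on the transposed 1024x256 data) natively
# yields (1024, 2), which cuBLAS wants transposed back to (2, 1024) -- exactly
# the compact layout already in hand. Hence only the columnwise data is
# transposed here; its scales stay put, while rowwise scales do need the
# transpose applied above.
m, k = 256, 1024
assert (m // 128, k) == (2, 1024)  # compact columnwise scale shape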
@dataclass
class _FP8BlockwiseAllGatherAsyncHandle:
"""Handle for asynchronous FP8 blockwise all-gather."""
tensor: Float8BlockwiseQTensorBase
quantizer: Float8BlockQuantizer
async_handle: torch.distributed.Work
_synchronized: bool = False
def wait(self) -> None:
"""Wait for the async operation to complete and post-process the tensor."""
if self._synchronized:
return
self.async_handle.wait()
_post_process_fp8_blockwise_gather(self.tensor, self.quantizer)
self._synchronized = True
def _all_gather_fp8_blockwise(
inp: torch.Tensor,
process_group: dist_group_type,
......@@ -990,8 +1055,9 @@ def _all_gather_fp8_blockwise(
Returns: quantizer(gather(inp))
NOTE: The implementation is not sophisticated enough to honor async_op=True.
In some cases it falls back to synchronous gather and invokes the quantizer.
NOTE: The implementation only honors async_op=True in the FP8 gather case.
When the tensor shape is not divisible by 128, the implementation falls back
to a synchronous gather and invokes the quantizer.
"""
# Input tensor attributes
......@@ -1027,7 +1093,11 @@ def _all_gather_fp8_blockwise(
out_shape[0] *= world_size
# Doing BF16 gather for now as baseline because it's simpler
if not isinstance(inp, Float8BlockwiseQTensorBase) and quantizer is not None:
if (
not isinstance(inp, Float8BlockwiseQTensorBase)
and quantizer is not None
and not quantizer.is_quantizable(inp)
):
out = torch.empty(
out_shape,
dtype=dtype,
......@@ -1035,14 +1105,93 @@ def _all_gather_fp8_blockwise(
memory_format=torch.contiguous_format,
)
torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
orig_all_gather_usage = quantizer.all_gather_usage
quantizer.all_gather_usage = False
out = quantizer(out)
quantizer.all_gather_usage = orig_all_gather_usage
return out, None
# Implementation of fp8 gather needs to account for:
# * Getting columnwise data as a transpose of how it is stored for GEMMS.
# * Gathering non GEMM swizzled scales.
# * Refer to scaffold code when implementing at:
# https://github.com/kwyss-nvidia/TransformerEngine/commit/6659ee9dc84fb515d1d47699d8bfd20a72b76477
raise NotImplementedError("fp8 blockwise allgather not yet implemented")
# Cast input tensor to Float8BlockwiseQTensor with required data
# Set to compact usage in case the quantizer is not correctly configured
orig_all_gather_usage = quantizer.all_gather_usage
quantizer.all_gather_usage = True
if not isinstance(inp, Float8BlockwiseQTensorBase):
inp = quantizer(inp)
elif (quantizer.rowwise_usage and inp._rowwise_data is None) or (
quantizer.columnwise_usage and inp._columnwise_data is None
):
warnings.warn(
"Input and quantizer do not have matching usages. "
"Dequantizing and requantizing to Float8BlockwiseQTensor."
)
inp = quantizer(inp.dequantize())
quantizer.all_gather_usage = orig_all_gather_usage
# Begin to do network communication, need to make sure compact format
if inp._data_format != tex.Float8BlockScaleTensorFormat.COMPACT:
raise RuntimeError(
"All-gather with FP8 block-wise quantized tensor requires compact data format, "
f"but found data_format={inp._data_format}"
)
# Construct Float8BlockwiseQTensor output tensor
out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
# Coalesce NCCL collectives
with torch.distributed._coalescing_manager(
group=process_group,
device=device,
async_ops=async_op,
) as coalescing_manager:
# Gather Float8BlockwiseQTensor data for row-wise usage
if quantizer.rowwise_usage:
# Launch all-gathers
torch.distributed.all_gather_into_tensor(
out._rowwise_scale_inv,
inp._rowwise_scale_inv,
group=process_group,
)
torch.distributed.all_gather_into_tensor(
out._rowwise_data,
inp._rowwise_data,
group=process_group,
)
# Gather Float8BlockwiseQTensor data for column-wise usage
if quantizer.columnwise_usage:
# Launch all-gathers
torch.distributed.all_gather_into_tensor(
out._columnwise_scale_inv,
inp._columnwise_scale_inv,
group=process_group,
)
torch.distributed.all_gather_into_tensor(
out._columnwise_data,
inp._columnwise_data,
group=process_group,
)
handle = coalescing_manager if async_op else None
# Unlike MXFP8, this FP8 blockwise tensor primarily targets Hopper,
# which means we need to transpose the gathered columnwise data.
# An example usage is the grad_output tensor, i.e. dY in linear backward:
# we want to gather two FP8 tensors (rowwise and columnwise) along dim0
# and then transpose the columnwise data to match the rowwise data.
# Make sure the FP8 transpose is populated if needed.
if async_op:
handle = _FP8BlockwiseAllGatherAsyncHandle(out, quantizer, handle)
else:
# if it's a sync op, we need to do the transpose here as post processing step
_post_process_fp8_blockwise_gather(out, quantizer, handle)
return out, handle
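A hedged usage sketch of the async path (assuming the elided signature accepts async_op and quantizer, as the function body suggests; process-group setup omitted):

# Kick off the coalesced all-gathers, overlap independent work, then wait();
# wait() also performs the COMPACT -> GEMM_READY post-processing.
out, handle = _all_gather_fp8_blockwise(
    inp, process_group, async_op=True, quantizer=quantizer
)
...  # overlap independent compute here
if handle is not None:
    handle.wait()  # transposes columnwise data / rowwise scales as needed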
def _all_gather_mxfp8(
......@@ -1267,6 +1416,9 @@ def gather_along_first_dim(
)
if isinstance(inp, QuantizedTensor):
inp = inp.dequantize()
# Falling back to high-precision all-gather for Float8BlockQuantizer
# means that it should directly output GEMM_READY format
_set_quantizer_format(quantizer, compact=False)
out = torch.empty(
out_shape,
dtype=inp.dtype,
......
......@@ -63,10 +63,11 @@ from ..tensor.quantized_tensor import (
)
from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.utils import any_feature_enabled
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ..cpp_extensions import (
......@@ -183,12 +184,6 @@ class _LayerNormLinear(torch.autograd.Function):
# All-gather is not supported with FP8 column-wise data
input_quantizer.set_usage(columnwise=False)
# Do TP communication in high precision if quantized format
# does not support communication
force_hp_blockwise_ln_out_gather = (
fp8 and with_input_all_gather and isinstance(input_quantizer, Float8BlockQuantizer)
)
# Avoid quantized norm kernel if norm output will be returned
# or if a gather of ln_out must be in high precision.
with_quantized_norm = (
......@@ -196,7 +191,6 @@ class _LayerNormLinear(torch.autograd.Function):
and not debug
and not return_layernorm_output
and not return_layernorm_output_gathered
and not force_hp_blockwise_ln_out_gather
)
# Apply normalization
......@@ -233,15 +227,16 @@ class _LayerNormLinear(torch.autograd.Function):
ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
ln_out_return = ln_out_total
if fp8 or debug:
if not force_hp_blockwise_ln_out_gather:
ln_out = input_quantizer(ln_out)
input_quantizer.set_usage(rowwise=True, columnwise=False)
if isinstance(input_quantizer, Float8BlockQuantizer):
input_quantizer.all_gather_usage = False
ln_out_total = input_quantizer(ln_out_total)
else:
quantizer = None
if fp8 or debug:
quantizer = input_quantizer
if not with_quantized_norm and not force_hp_blockwise_ln_out_gather:
if not with_quantized_norm:
ln_out = quantizer(ln_out)
quantizer.set_usage(rowwise=True, columnwise=False)
if ub_overlap_ag_fprop: # Initialize Userbuffers all-gather
......@@ -391,7 +386,6 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.ln_out_needs_gather = (
weight.requires_grad and parallel_mode == "column" and sequence_parallel
)
ctx.force_hp_blockwise_ln_out_gather = force_hp_blockwise_ln_out_gather
# Input with column-wise usage is needed for wgrad GEMM.
if backward_needs_input:
......@@ -399,7 +393,10 @@ class _LayerNormLinear(torch.autograd.Function):
# For sequence parallel in vanilla FP8, rowwise data is used
# to gather the input. For MXFP8, only columnwise data
# can be all-gathered.
if isinstance(ln_out, MXFP8TensorBase) or not ctx.ln_out_needs_gather:
if (
isinstance(ln_out, (MXFP8TensorBase, Float8BlockwiseQTensorBase))
or not ctx.ln_out_needs_gather
):
ln_out.update_usage(rowwise_usage=False)
# Weight with column-wise usage is needed for dgrad GEMM.
......@@ -496,8 +493,8 @@ class _LayerNormLinear(torch.autograd.Function):
if return_layernorm_output:
if return_layernorm_output_gathered:
shape = list(inp_shape)
shape[0] *= tp_size
shape = list(inp.shape)
shape[0] *= tp_size if with_input_all_gather else 1
return out, ln_out_return.view(shape)
return out, ln_out_return.view(inp_shape)
return out
......@@ -631,7 +628,7 @@ class _LayerNormLinear(torch.autograd.Function):
ln_out_total_work = None
if ctx.ln_out_needs_gather:
quantizer = None
if ctx.input_quantizer is not None and not ctx.force_hp_blockwise_ln_out_gather:
if ctx.input_quantizer is not None:
quantizer = ctx.input_quantizer
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
# If data is in FP8, we compute FP8 transposes manually
......@@ -1376,6 +1373,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_current_scaling():
self._customize_quantizers_float8_current_scaling(fwd, recipe)
elif recipe.float8_block_scaling():
self._customize_quantizers_float8_blockwise_scaling(fwd, recipe)
# elif other recipes (mxfp8, etc)
def reset_layer_norm_parameters(self) -> None:
......@@ -1677,3 +1676,14 @@ class LayerNormLinear(TransformerEngineBaseModule):
weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
weight_quantizer.internal = True
return [weight_quantizer]
def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on blockwise scaling recipe + layernorm_linear."""
assert (
recipe.float8_block_scaling()
), "blockwise scaling recipe quantizer customization here"
if fwd:
if self.sequence_parallel and self.parallel_mode == "column":
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].all_gather_usage = True
......@@ -243,26 +243,18 @@ class _LayerNormMLP(torch.autograd.Function):
# All-gather is not supported with FP8 column-wise data
fc1_input_quantizer.set_usage(columnwise=False)
# Do TP communication in high precision if quantized format
# does not support communication
force_hp_fc1_input_gather = (
fp8 and sequence_parallel and isinstance(fc1_input_quantizer, Float8BlockQuantizer)
)
# for fp8 DelayedScaling: layernorm output = FP8
# only output of the linear is returned
# for return_layernorm_output: layernorm output = High precision, then cast to FP8
# high precision layernorm output and output of the linear are returned
# for debug: : layernorm output = High precision to enable processing of this norm
with_quantized_norm = (
fp8
and not debug
and not return_layernorm_output
and not return_layernorm_output_gathered
and not debug
)
if isinstance(fc1_input_quantizer, Float8BlockQuantizer):
# Kernels not available for norm fusion.
with_quantized_norm = False
# Apply normalization
ln_out, mu, rsigma = apply_normalization(
......@@ -292,15 +284,16 @@ class _LayerNormMLP(torch.autograd.Function):
ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
ln_out_return = ln_out_total
if fp8 or debug:
if not force_hp_fc1_input_gather:
ln_out = fc1_input_quantizer(ln_out)
fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
if isinstance(fc1_input_quantizer, Float8BlockQuantizer):
fc1_input_quantizer.all_gather_usage = False
ln_out_total = fc1_input_quantizer(ln_out_total)
else:
quantizer = None
if fp8 or debug:
quantizer = fc1_input_quantizer
if not with_quantized_norm and not force_hp_fc1_input_gather:
if not with_quantized_norm:
ln_out = fc1_input_quantizer(ln_out)
fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
if ub_overlap_ag:
......@@ -566,7 +559,6 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.tensor_objects = tensor_objects
ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
ctx.force_hp_fc1_input_gather = force_hp_fc1_input_gather
ctx.fc1_grad_input_quantizer = fc1_grad_input_quantizer
ctx.fc1_grad_weight_quantizer = fc1_grad_weight_quantizer
ctx.fc1_grad_output_quantizer = fc1_grad_output_quantizer
......@@ -627,7 +619,7 @@ class _LayerNormMLP(torch.autograd.Function):
if return_layernorm_output:
if return_layernorm_output_gathered:
shape = list(inp_shape)
shape[0] *= tp_size
shape[0] *= tp_size if (sequence_parallel and set_parallel_mode) else 1
return fc2_out, ln_out_return.view(shape)
return fc2_out, ln_out_return.view(inp_shape)
return fc2_out
......@@ -742,7 +734,7 @@ class _LayerNormMLP(torch.autograd.Function):
ub_obj_fc1_dgrad = None
if ctx.fc1_weight_requires_grad and ctx.tensor_parallel and ctx.sequence_parallel:
quantizer = None
if ctx.fp8 or ctx.debug and not ctx.force_hp_fc1_input_gather:
if ctx.fp8 or ctx.debug:
quantizer = ctx.fc1_input_quantizer
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
# If data is in FP8, we compute FP8 transposes manually
......@@ -1643,8 +1635,11 @@ class LayerNormMLP(TransformerEngineBaseModule):
super().set_meta_tensor(fwd, recipe)
# customize quantizers based on each recipe & layer configs
if FP8GlobalStateManager.get_fp8_recipe().float8_current_scaling():
recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_current_scaling():
self._customize_quantizers_float8_current_scaling(fwd, recipe)
elif recipe.float8_block_scaling():
self._customize_quantizers_float8_blockwise_scaling(fwd, recipe)
# elif for other recipes (mxfp8, etc.)
def reset_layer_norm_parameters(self) -> None:
......@@ -1996,6 +1991,22 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc2_weight_quantizer.internal = True
return [fc1_weight_quantizer, fc2_weight_quantizer]
def _customize_quantizers_float8_blockwise_scaling(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on blockwise scaling recipe + layernorm_mlp."""
assert (
recipe.float8_block_scaling()
), "blockwise scaling recipe quantizer customization here"
if fwd:
if self.sequence_parallel and self.set_parallel_mode:
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].all_gather_usage = True
else:
if self.sequence_parallel and self.set_parallel_mode:
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT2
].all_gather_usage = True
def backward_dw(self):
"""
Execute the delayed weight gradient computation.
......