Unverified Commit 4742c0f8 authored by Kunlun Li and committed by GitHub

Enable fp8 primary weights for sub-channel recipe (#1641)



* Add fp8_primary_weights support for blockwise scaling
Signed-off-by: kunlunl <kunlunl@nvidia.com>

custom fsdp
Signed-off-by: kunlunl <kunlunl@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



Add view to blockwise fp8 tensor
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Fix columnwise_shape in backward of view()
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Add comments on the unit of start_offset
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Add test for view and reshape for blockwise fp8 tensor
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Add implementation for the case where self._columnwise_scale_inv does not exist
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Track down checks for _columnwise_data is None and add checks for _columnwise_invalid
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add assertion to check whether ._quantizer is None
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* rename partial_cast.cu -> fp8_block_scaling_partial_cast.cu
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* rename partial_cast kernel to fp8_block_scaling_partial_cast kernel
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add shfl_sync in partial cast kernel
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Remove columnwise_invalid flag
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add comments about out-of-bounds write
Signed-off-by: kunlunl <kunlunl@nvidia.com>

---------
Signed-off-by: kunlunl <kunlunl@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 26db7f34
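This change enables storing primary weights in FP8 under the sub-channel (blockwise) recipe. A minimal usage sketch mirroring the test setup in this diff; the exact fp8_model_init signature is an assumption and may differ between Transformer Engine versions:

import torch
import torch.nn as nn
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8BlockScaling, Format

# Sub-channel (blockwise) recipe; HYBRID uses E4M3 for forward and E5M2 for backward.
recipe = Float8BlockScaling(fp8_format=Format.HYBRID)

# Assumed API: keep primary weights directly in FP8 under this recipe while preserving
# a high-precision copy of the initial values (as the tests below do).
with te.fp8_model_init(recipe=recipe, preserve_high_precision_init_val=True):
    model = nn.Sequential(
        te.Linear(128, 256 + 16, params_dtype=torch.bfloat16),
        te.Linear(256 + 16, 128, params_dtype=torch.bfloat16),
    )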
......@@ -16,6 +16,7 @@ import torch.distributed as dist
from transformer_engine.common.recipe import (
DelayedScaling,
Float8CurrentScaling,
Float8BlockScaling,
Format,
Recipe,
)
......@@ -26,6 +27,7 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.tensor.utils import replace_raw_data
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor
def _get_raw_data(quantized_tensor):
......@@ -34,6 +36,14 @@ def _get_raw_data(quantized_tensor):
assert hasattr(quantized_tensor, "_data"), "Float8Tensor does not have _data attribute"
assert quantized_tensor._data.dtype == torch.uint8, "Float8Tensor _data must be uint8"
return quantized_tensor._data
elif isinstance(quantized_tensor, Float8BlockwiseQTensor):
assert hasattr(
quantized_tensor, "_rowwise_data"
), "Float8BlockwiseQTensor does not have _rowwise_data attribute"
assert (
quantized_tensor._rowwise_data.dtype == torch.uint8
), "Float8BlockwiseQTensor _rowwise_data must be uint8"
return quantized_tensor._rowwise_data
else:
raise ValueError(f"Unsupported quantized tensor type: {type(quantized_tensor)}")
......@@ -435,15 +445,15 @@ def _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group):
preserve_high_precision_init_val=True,
):
model_fp8 = nn.Sequential(
te.Linear(128, 256, **linear_kwargs),
te.Linear(256, 256 * 3, **linear_kwargs),
te.Linear(128, 256 + 16, **linear_kwargs),
te.Linear(256 + 16, 256 * 3, **linear_kwargs),
te.Linear(256 * 3, 128, **linear_kwargs),
)
# Create model with BF16 weights
model = nn.Sequential(
te.Linear(128, 256, **linear_kwargs),
te.Linear(256, 256 * 3, **linear_kwargs),
te.Linear(128, 256 + 16, **linear_kwargs),
te.Linear(256 + 16, 256 * 3, **linear_kwargs),
te.Linear(256 * 3, 128, **linear_kwargs),
)
......@@ -539,12 +549,13 @@ def _test_zero_1(dp_group):
def quantization_recipe(quantization) -> Recipe:
"""Quantization recipe setup"""
fp8_format = Format.HYBRID
if quantization == "fp8":
return DelayedScaling(
fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max"
)
return DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")
elif quantization == "fp8_cs":
return Float8CurrentScaling()
return Float8CurrentScaling(fp8_format=fp8_format)
elif quantization == "fp8_block":
return Float8BlockScaling(fp8_format=fp8_format)
else:
raise ValueError(f"Unsupported quantization: {quantization}")
......@@ -568,15 +579,15 @@ def _test_cast_master_weights_to_fp8(quantization, dp_group):
preserve_high_precision_init_val=True,
):
model_fp8 = nn.Sequential(
te.Linear(128, 256, **linear_kwargs),
te.Linear(256, 256 * 3, **linear_kwargs),
te.Linear(128, 256 + 16, **linear_kwargs),
te.Linear(256 + 16, 256 * 3, **linear_kwargs),
te.Linear(256 * 3, 128, **linear_kwargs),
)
# Create model with BF16 weights
model = nn.Sequential(
te.Linear(128, 256, **linear_kwargs),
te.Linear(256, 256 * 3, **linear_kwargs),
te.Linear(128, 256 + 16, **linear_kwargs),
te.Linear(256 + 16, 256 * 3, **linear_kwargs),
te.Linear(256 * 3, 128, **linear_kwargs),
)
......@@ -593,7 +604,7 @@ def _test_cast_master_weights_to_fp8(quantization, dp_group):
optimizer_fp8 = MiniZero_1([w for w in model_fp8.parameters()], 10.0, dp_group)
optimizer = MiniZero_1([w for w in model.parameters()], 10.0, dp_group)
for _ in range(100):
for i in range(100):
for w_fp8, w in zip(model_fp8.parameters(), model.parameters()):
w_fp8.main_grad.zero_()
w.main_grad.zero_()
......@@ -654,7 +665,9 @@ def main(argv=None, namespace=None):
dist.init_process_group(**dist_init_kwargs)
parser = argparse.ArgumentParser()
parser.add_argument("--quantization", type=str, default=None, choices=["fp8", "fp8_cs"])
parser.add_argument(
"--quantization", type=str, default=None, choices=["fp8", "fp8_cs", "fp8_block"]
)
args = parser.parse_args(argv, namespace)
dp_group = dist.new_group(backend="nccl")
......
......@@ -28,7 +28,7 @@ def _run_test(quantization):
assert result.returncode == 0
@pytest.mark.parametrize("quantization", ["fp8", "fp8_cs"])
@pytest.mark.parametrize("quantization", ["fp8", "fp8_cs", "fp8_block"])
def test_cast_master_weights_to_fp8(quantization):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
......
......@@ -392,6 +392,110 @@ class TestFloat8BlockwiseTensor:
with pytest.raises(AssertionError):
torch.testing.assert_close(x_view.dequantize(), -x_hp, **_tols[fp8_dtype])
@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2], ids=str)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32], ids=str)
@pytest.mark.parametrize(
"dims", [[16, 16, 512], [16, 16, 512, 16], [12, 7, 11], [13, 14, 16], [2, 3, 5]]
)
def test_view_and_reshape_1D(
self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: List[int]
) -> None:
"""Test view operations that preserve tensor shape"""
device = "cuda"
def is_bitwise_equal(a, b):
if a.numel() != b.numel():
return False
a_flat = a.reshape(-1).view(torch.uint8)
b_flat = b.reshape(-1).view(torch.uint8)
return torch.all((a_flat ^ b_flat) == 0)
x_hp = torch.rand(dims, dtype=dtype, device=device)
quantizer = Float8BlockQuantizer(
fp8_dtype=fp8_dtype,
rowwise=True,
columnwise=True,
block_scaling_dim=1,
)
x_fp8 = quantizer.make_empty(x_hp.shape, dtype=dtype, device=device)
quantizer.update_quantized(x_hp.clone(), x_fp8)
# Test view, high dimension tensor -> 2D tensor
x_hp_view = x_hp.view(-1, dims[-1]).contiguous()
x_fp8_view = x_fp8.view(-1, dims[-1])
# Check the dequantized result
torch.testing.assert_close(
x_fp8_view.dequantize().contiguous(), x_hp_view, **_tols[fp8_dtype]
)
# Check the bitwise equality of the inner data
assert is_bitwise_equal(x_fp8_view._rowwise_data, x_fp8._rowwise_data)
assert is_bitwise_equal(x_fp8_view._rowwise_scale_inv, x_fp8._rowwise_scale_inv)
# Check the data ptr
assert x_fp8_view._rowwise_data.data_ptr() == x_fp8._rowwise_data.data_ptr()
assert x_fp8_view._rowwise_scale_inv.data_ptr() == x_fp8._rowwise_scale_inv.data_ptr()
# Test reshape high dimension tensor -> 2D tensor
x_hp_reshape = x_hp.reshape(-1, dims[-1]).contiguous()
x_fp8_reshape = x_fp8.reshape(-1, dims[-1])
# Check the dequantized result
torch.testing.assert_close(
x_fp8_reshape.dequantize().contiguous(), x_hp_reshape, **_tols[fp8_dtype]
)
# Check the bitwise equality of the inner data
assert is_bitwise_equal(x_fp8_reshape._rowwise_data, x_fp8._rowwise_data)
assert is_bitwise_equal(x_fp8_reshape._rowwise_scale_inv, x_fp8._rowwise_scale_inv)
@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2], ids=str)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32], ids=str)
@pytest.mark.parametrize("dims", [[16, 16, 512, 16], [2, 512, 512, 128], [3, 13, 14, 16]])
def test_view_and_reshape_2D(
self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: List[int]
) -> None:
"""Test view operations that preserve tensor shape"""
device = "cuda"
def is_bitwise_equal(a, b):
if a.numel() != b.numel():
return False
a_flat = a.reshape(-1).view(torch.uint8)
b_flat = b.reshape(-1).view(torch.uint8)
return torch.all((a_flat ^ b_flat) == 0)
x_hp = torch.rand(dims, dtype=dtype, device=device)
quantizer = Float8BlockQuantizer(
fp8_dtype=fp8_dtype,
rowwise=True,
columnwise=True,
block_scaling_dim=2,
)
x_fp8 = quantizer.make_empty(x_hp.shape, dtype=dtype, device=device)
quantizer.update_quantized(x_hp.clone(), x_fp8)
# Test view, high dimension tensor -> 3D tensor
x_hp_view = x_hp.view(-1, dims[-2], dims[-1]).contiguous()
x_fp8_view = x_fp8.view(-1, dims[-2], dims[-1])
# Check the dequantized result
torch.testing.assert_close(
x_fp8_view.dequantize().contiguous(), x_hp_view, **_tols[fp8_dtype]
)
# Check the bitwise equality of the inner data
assert is_bitwise_equal(x_fp8_view._rowwise_data, x_fp8._rowwise_data)
assert is_bitwise_equal(x_fp8_view._rowwise_scale_inv, x_fp8._rowwise_scale_inv)
# Check the data ptr
assert x_fp8_view._rowwise_data.data_ptr() == x_fp8._rowwise_data.data_ptr()
assert x_fp8_view._rowwise_scale_inv.data_ptr() == x_fp8._rowwise_scale_inv.data_ptr()
# Test reshape high dimension tensor -> 3D tensor
x_hp_reshape = x_hp.reshape(-1, dims[-2], dims[-1]).contiguous()
x_fp8_reshape = x_fp8.reshape(-1, dims[-2], dims[-1])
# Check the dequantized result
torch.testing.assert_close(
x_fp8_reshape.dequantize().contiguous(), x_hp_reshape, **_tols[fp8_dtype]
)
# Check the bitwise equality of the inner data
assert is_bitwise_equal(x_fp8_reshape._rowwise_data, x_fp8._rowwise_data)
assert is_bitwise_equal(x_fp8_reshape._rowwise_scale_inv, x_fp8._rowwise_scale_inv)
@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str)
@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str)
@pytest.mark.parametrize("dims", [[256, 512], [250, 500]])
......
......@@ -264,6 +264,15 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reductio
const std::string &amax_compute_algo,
transformer_engine::DType fp8_dtype, float margin);
// Note that start_offset is a logical offset in elements along the flattened tensor.
// The offset in bytes is start_offset * sizeof(tensor.dtype).
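// For example: for a BFloat16 master weight, a shard that begins 1024 elements into the
// flattened [h, w] tensor has start_offset = 1024, i.e. a byte offset of 1024 * 2 = 2048.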
void fp8_block_scaling_compute_partial_amax(const at::Tensor &tensor, at::Tensor amax, size_t h,
size_t w, size_t start_offset, size_t block_len);
void fp8_block_scaling_partial_cast(const at::Tensor &inp, at::Tensor out, const at::Tensor &scale,
size_t h, size_t w, size_t start_offset, size_t block_len,
const transformer_engine::DType out_dtype);
/***************************************************************************************************
* Rotary positional embedding
**************************************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "common/common.h"
#include "common/utils.cuh"
#include "extensions.h"
#include "type_shim.h"
constexpr int kTileDim = 128;
constexpr int kThreadsPerBlock = 256;
template <typename IType>
__global__ void __launch_bounds__(kThreadsPerBlock)
fp8_block_scaling_compute_partial_amax_kernel(const IType *input, float *amax_ptr,
const size_t amax_stride_h,
const size_t amax_stride_w, const size_t h,
const size_t w, const size_t start_offset,
const size_t len) {
constexpr int kThreadsPerWarp = 32;
constexpr int kLoopsPerRow = kTileDim / kThreadsPerWarp;
constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp;
constexpr int kLoopsPerCol = kTileDim / kNumWarps;
const int tile_col = blockIdx.x;
const int tile_row = blockIdx.y;
const size_t end_offset = start_offset + len;
const IType *input_minus_offset = input - start_offset;
__shared__ float smem[kNumWarps];
float amax = 0.0f;
for (int loop_col = 0; loop_col < kLoopsPerCol; ++loop_col) {
size_t r = tile_row * kTileDim + loop_col * kNumWarps + threadIdx.x / kThreadsPerWarp;
for (int loop_row = 0; loop_row < kLoopsPerRow; ++loop_row) {
size_t c = tile_col * kTileDim + loop_row * kThreadsPerWarp + (threadIdx.x % kThreadsPerWarp);
size_t idx = r * w + c;
if (r < h && c < w && idx >= start_offset && idx < end_offset) {
float other_amax = fabs(static_cast<float>(input_minus_offset[idx]));
__builtin_assume(amax >= 0);
__builtin_assume(other_amax >= 0);
amax = fmaxf(amax, other_amax);
}
}
}
for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) {
float other_amax = __shfl_down_sync(0xFFFFFFFF, amax, delta);
__builtin_assume(amax >= 0);
__builtin_assume(other_amax >= 0);
amax = fmaxf(amax, other_amax);
}
if (threadIdx.x % kThreadsPerWarp == 0) {
smem[threadIdx.x / kThreadsPerWarp] = amax;
}
__syncthreads();
if (threadIdx.x == 0) {
for (int i = 0; i < kNumWarps; ++i) {
float other_amax = smem[i];
__builtin_assume(amax >= 0);
__builtin_assume(other_amax >= 0);
amax = fmaxf(amax, other_amax);
}
amax_ptr[tile_row * amax_stride_h + tile_col * amax_stride_w] = amax;
}
}
template <typename IType, typename OType, bool kWidthAligned>
__global__ void __launch_bounds__(kThreadsPerBlock)
fp8_block_scaling_partial_cast_kernel(const IType *input, OType *output, const float *scale_ptr,
const size_t scale_stride_h, const size_t scale_stride_w,
const size_t h, const size_t w, const size_t start_offset,
const size_t len) {
using transformer_engine::Vec;
static_assert(sizeof(OType) == 1);
constexpr int kNumOutputElemsPerBank = 4 / sizeof(OType);
constexpr int kThreadsPerWarp = 32;
constexpr int kLoopsPerRow = kTileDim / kThreadsPerWarp;
constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp;
constexpr int kRowsPerWarp = kTileDim / kNumWarps;
__shared__ OType smem[kTileDim][kTileDim + kNumOutputElemsPerBank];
const int tile_w = blockIdx.x;
const int tile_h = blockIdx.y;
const size_t end_offset = start_offset + len;
const IType *input_minus_offset = input - start_offset;
OType *output_minus_offset = output - start_offset;
const float scale = scale_ptr[tile_h * scale_stride_h + tile_w * scale_stride_w];
// Load input data into shared memory
bool skip_store = true;
for (int i = 0; i < kRowsPerWarp; ++i) {
for (int j = 0; j < kLoopsPerRow; ++j) {
const int h_in_smem = threadIdx.x / kThreadsPerWarp * kRowsPerWarp + i;
const int w_in_smem = threadIdx.x % kThreadsPerWarp + kThreadsPerWarp * j;
const int h_in_input = tile_h * kTileDim + h_in_smem;
const int w_in_input = tile_w * kTileDim + w_in_smem;
const size_t idx_in_input = static_cast<size_t>(h_in_input) * w + w_in_input;
if (h_in_input < h && w_in_input < w && idx_in_input >= start_offset &&
idx_in_input < end_offset) {
float inp = static_cast<float>(input_minus_offset[idx_in_input]) * scale;
smem[h_in_smem][w_in_smem] = static_cast<OType>(inp);
skip_store = false;
}
}
}
for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) {
bool other_skip_store = __shfl_down_sync(0xFFFFFFFF, skip_store, delta);
skip_store = skip_store && other_skip_store;
}
skip_store = __shfl_sync(0xFFFFFFFF, skip_store, 0);
if (skip_store) {
return;
}
// Store the casted data into the output.
// Note that this store operation might write "out-of-bounds", but it is intentional:
// 1. The "out-of-bounds" here only crosses the boundary of the "local shard" (i.e., the region
// from start_offset to end_offset), not the boundary of the entire output memory. Therefore,
// this out-of-bounds write will not cause illegal memory access.
// 2. We assume that the subsequent all-gather operation happens in-place, so any parts that
// should not be updated here will be overwritten by the all-gather.
// This tricky approach allows us to avoid checking whether each output index falls within
// [start, end), resulting in a significant performance improvement.
Vec<OType, kNumOutputElemsPerBank> vec_output;
for (int i = 0; i < kRowsPerWarp; ++i) {
const int row_in_smem = threadIdx.x / kThreadsPerWarp * kRowsPerWarp + i;
const int col_in_smem = threadIdx.x % kThreadsPerWarp * kNumOutputElemsPerBank;
for (int j = 0; j < kNumOutputElemsPerBank; ++j) {
vec_output.data.elt[j] = smem[row_in_smem][col_in_smem + j];
}
const int row_in_output = tile_h * kTileDim + row_in_smem;
const int col_in_output = tile_w * kTileDim + col_in_smem;
const size_t idx_in_output = static_cast<size_t>(row_in_output) * w + col_in_output;
if (row_in_output < h) {
if constexpr (kWidthAligned) {
vec_output.store_to(output_minus_offset + idx_in_output);
} else {
int num = min(static_cast<size_t>(kNumOutputElemsPerBank),
static_cast<size_t>(col_in_output < w ? w - col_in_output : 0));
vec_output.store_to_elts(output_minus_offset, idx_in_output, num);
}
}
}
}
void fp8_block_scaling_compute_partial_amax(const at::Tensor &tensor, at::Tensor amax, size_t h,
size_t w, size_t start_offset, size_t block_len) {
TORCH_CHECK(block_len == 128, "Currently only support block_len = 128");
TORCH_CHECK(amax.dim() == 2, "amax must be a 2D tensor");
TORCH_CHECK(amax.scalar_type() == at::ScalarType::Float, "amax must be a float tensor");
TORCH_CHECK(tensor.scalar_type() == at::ScalarType::Float ||
tensor.scalar_type() == at::ScalarType::BFloat16,
"tensor must be a float or bfloat16 tensor");
size_t amax_stride_h = amax.stride(0);
size_t amax_stride_w = amax.stride(1);
size_t len = tensor.numel();
assert(h > 0 && w > 0);
assert(start_offset < h * w);
assert(start_offset + len <= h * w);
size_t blocks_x = (w + kTileDim - 1) / kTileDim;
size_t blocks_y = (h + kTileDim - 1) / kTileDim;
assert(blocks_x <= std::numeric_limits<unsigned int>::max());
assert(blocks_y <= std::numeric_limits<unsigned int>::max());
dim3 grid(blocks_x, blocks_y);
auto stream = at::cuda::getCurrentCUDAStream();
DISPATCH_FLOAT_HALF_AND_BFLOAT(tensor.scalar_type(), 0, "compute_partial_amax",
fp8_block_scaling_compute_partial_amax_kernel<scalar_t_0>
<<<grid, kThreadsPerBlock, 0, stream>>>(
tensor.data_ptr<scalar_t_0>(), amax.data_ptr<float>(),
amax_stride_h, amax_stride_w, h, w, start_offset, len);)
}
void fp8_block_scaling_partial_cast(const at::Tensor &inp, at::Tensor out, const at::Tensor &scale,
size_t h, size_t w, size_t start_offset, size_t block_len,
const transformer_engine::DType out_dtype) {
TORCH_CHECK(block_len == 128, "Currently only support block_len = 128");
TORCH_CHECK(scale.dim() == 2, "scale must be a 2D tensor");
TORCH_CHECK(scale.scalar_type() == at::ScalarType::Float, "scale must be a float tensor");
TORCH_CHECK(
inp.scalar_type() == at::ScalarType::Float || inp.scalar_type() == at::ScalarType::BFloat16,
"input must be a float or bfloat16 tensor");
TORCH_CHECK(out.scalar_type() == at::ScalarType::Byte, "output must be a uint8 tensor");
TORCH_CHECK(out_dtype == transformer_engine::DType::kFloat8E4M3 ||
out_dtype == transformer_engine::DType::kFloat8E5M2,
"out_dtype must be kFloat8E4M3 or kFloat8E5M2");
size_t scale_stride_h = scale.stride(0);
size_t scale_stride_w = scale.stride(1);
size_t len = inp.numel();
assert(h > 0 && w > 0);
assert(start_offset < h * w);
assert(start_offset + len <= h * w);
size_t blocks_x = (w + kTileDim - 1) / kTileDim;
size_t blocks_y = (h + kTileDim - 1) / kTileDim;
assert(blocks_x <= std::numeric_limits<unsigned int>::max());
assert(blocks_y <= std::numeric_limits<unsigned int>::max());
dim3 grid(blocks_x, blocks_y);
auto stream = at::cuda::getCurrentCUDAStream();
DISPATCH_FLOAT_HALF_AND_BFLOAT(
inp.scalar_type(), 0, "partial_cast",
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
out_dtype, fp8_type,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
w % kTileDim == 0, kWidthAligned,
fp8_block_scaling_partial_cast_kernel<scalar_t_0, fp8_type, kWidthAligned>
<<<grid, kThreadsPerBlock, 0, stream>>>(inp.data_ptr<scalar_t_0>(),
reinterpret_cast<fp8_type *>(out.data_ptr()),
scale.data_ptr<float>(), scale_stride_h,
scale_stride_w, h, w, start_offset, len);)))
}
......@@ -210,6 +210,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fused_amax_and_scale_update_after_reduction", &fused_amax_and_scale_update_after_reduction,
"Update amax history and FP8 scale/scale_inv after reduction",
py::call_guard<py::gil_scoped_release>());
m.def("fp8_block_scaling_compute_partial_amax", &fp8_block_scaling_compute_partial_amax,
"Compute partial amax from master weights for fp8 block scaling", py::arg("tensor"),
py::arg("amax"), py::arg("h"), py::arg("w"), py::arg("start_offset"), py::arg("block_len"));
m.def("fp8_block_scaling_partial_cast", &fp8_block_scaling_partial_cast,
"Partial cast from master weights for fp8 block scaling", py::arg("inp"), py::arg("out"),
py::arg("scale"), py::arg("h"), py::arg("w"), py::arg("start_offset"), py::arg("block_len"),
py::arg("out_dtype"));
m.def("fused_multi_row_padding", &fused_multi_row_padding, "Fused Multi-tensor padding",
py::call_guard<py::gil_scoped_release>());
......
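The new bindings are exercised from Python in tensor/utils.py later in this diff; a standalone sketch of the call shape with illustrative sizes (the shard here covers the first half of a 256 x 512 weight):

import torch
import transformer_engine_torch as tex

h, w, block_len = 256, 512, 128
start_offset = 0  # element offset of this shard into the flattened [h, w] weight
master_shard = torch.rand(h * w // 2, dtype=torch.float32, device="cuda")
# One amax entry per 128x128 block; blocks not covered by this shard stay 0.
amax = torch.zeros(h // block_len, w // block_len, dtype=torch.float32, device="cuda")
tex.fp8_block_scaling_compute_partial_amax(master_shard, amax, h, w, start_offset, block_len)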
......@@ -9,6 +9,7 @@ import math
from typing import Optional, Dict, Any, Tuple
import torch
import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
from ...constants import TE_DType_To_Torch
......@@ -232,6 +233,38 @@ class Float8BlockwiseQTensorBase:
reordered.append(dims[0])
return torch.Size(reordered)
def _create_columnwise(self):
"""
Update columnwise data and columnwise scale inv. Can only be used when using 2D scaling.
"""
assert self._is_2D_scaled, "Cannot create columnwise data when not using 2D scaling."
rowwise_data = self._rowwise_data
if not rowwise_data.is_contiguous():
rowwise_data = rowwise_data.contiguous()
self._columnwise_data = tex.fp8_transpose(
rowwise_data, self._fp8_dtype, out=self._columnwise_data
)
if self._columnwise_scale_inv is None:
assert self._quantizer is not None, (
"._quantizer of Float8BlockwiseQTensor cannot be None because all the blockwise "
"quantized tensors are supposed to be generated from the quantizer."
)
columnwise_scale_inv_shape = self._quantizer.get_scale_shape(rowwise_data.shape, True)
self._columnwise_scale_inv = torch.empty(
columnwise_scale_inv_shape,
dtype=self._rowwise_scale_inv.dtype,
device=self._rowwise_scale_inv.device,
)
assert len(self._rowwise_scale_inv.shape) == 2
assert len(self._columnwise_scale_inv.shape) == 2
rowwise_scale_inv = self._rowwise_scale_inv
columnwise_scale_inv = rowwise_scale_inv.transpose(-2, -1)
h = min(self._columnwise_scale_inv.shape[0], columnwise_scale_inv.shape[0])
w = min(self._columnwise_scale_inv.shape[1], columnwise_scale_inv.shape[1])
self._columnwise_scale_inv[0:h, 0:w].copy_(columnwise_scale_inv[0:h, 0:w])
def __repr__(self):
if self._rowwise_data is not None:
data = self.dequantize()
......
......@@ -325,12 +325,23 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
), "Must retain some data either columnwise or rowwise"
if columnwise_usage and rowwise_usage:
if not self._is_2D_scaled:
# For 1D scaling, we cannot create columnwise data/scale_inv from rowwise
# data/scale_inv because their scale values are different.
assert (
self._rowwise_data is not None
and self._rowwise_scale_inv is not None
and self._columnwise_data is not None
and self._columnwise_scale_inv is not None
), "Cannot update to rowwise and columnwise usage."
else:
# For 2D scaling, if columnwise data/scale_inv is None, we can create them from
# rowwise data/scale_inv.
assert (
self._rowwise_data is not None and self._rowwise_scale_inv is not None
), "Cannot update to rowwise and columnwise usage because rowwise data is None."
if self._columnwise_data is None or self._columnwise_scale_inv is None:
self._create_columnwise()
return
if rowwise_usage:
......@@ -544,15 +555,65 @@ class _ViewFunc(torch.autograd.Function):
# pylint: disable=missing-function-docstring
# Return input tensor if shape is not provided
if ctx is not None:
ctx.shape = tensor.shape
if shape is None:
return tensor
if list(shape) != list(tensor.shape):
raise NotImplementedError("View not implemented.")
# Canonicalize shape
if not isinstance(shape, Iterable):
shape = [shape]
elif len(shape) == 1 and isinstance(shape[0], Iterable):
shape = shape[0]
if -1 in shape:
shape = list(shape)
d_inferred = -math.prod(ctx.shape) // math.prod(shape)
for i, d in enumerate(shape):
if d == -1:
shape[i] = d_inferred
break
if tensor._is_2D_scaled:
# For the case of 2D scaled tensor, the last 2 dimensions should not change
if shape[-1] != ctx.shape[-1] or shape[-2] != ctx.shape[-2]:
raise RuntimeError(
"2D scaled Float8BlockwiseQTensor does not support view "
"the last 2 dimensions "
f"(attempted to view dims={tuple(tensor.shape)} to {tuple(shape)})"
)
else:
# For the case of 1D scaled tensor, the last dimension should not change
if shape[-1] != ctx.shape[-1]:
raise RuntimeError(
"1D scaled Float8BlockwiseQTensor does not support view "
"the last dimension "
f"(attempted to view dims={tuple(tensor.shape)} to {tuple(shape)})"
)
if list(shape) == list(tensor.shape):
return tensor
# Construct new tensor if shape is provided
new_rowwise_data = None
new_columnwise_data = None
if tensor._rowwise_data is not None:
new_rowwise_data = tensor._rowwise_data.view(*shape)
if tensor._columnwise_data is not None:
columnwise_shape = [shape[-1]] + list(shape[:-1])
new_columnwise_data = tensor._columnwise_data.view(columnwise_shape)
return Float8BlockwiseQTensor(
shape=shape,
dtype=tensor.dtype,
fp8_dtype=tensor._fp8_dtype,
rowwise_data=new_rowwise_data,
rowwise_scale_inv=tensor._rowwise_scale_inv,
columnwise_data=new_columnwise_data,
columnwise_scale_inv=tensor._columnwise_scale_inv,
quantizer=tensor._quantizer,
is_2D_scaled=tensor._is_2D_scaled,
requires_grad=tensor.requires_grad,
)
@staticmethod
def backward(
ctx,
......@@ -561,7 +622,27 @@ class _ViewFunc(torch.autograd.Function):
# pylint: disable=missing-function-docstring
if isinstance(grad, Float8BlockwiseQTensor):
raise NotImplementedError("View bwd not implemented")
new_data = (
grad._rowwise_data.view(*ctx.shape) if grad._rowwise_data is not None else None
)
if grad._columnwise_data is not None:
columnwise_shape = [ctx.shape[-1]] + list(ctx.shape[:-1])
new_columnwise_data = grad._columnwise_data.view(columnwise_shape)
else:
new_columnwise_data = None
dgrad = Float8BlockwiseQTensor(
shape=ctx.shape,
dtype=grad.dtype,
rowwise_data=new_data,
rowwise_scale_inv=grad._rowwise_scale_inv,
columnwise_data=new_columnwise_data,
columnwise_scale_inv=grad._columnwise_scale_inv,
fp8_dtype=grad._fp8_dtype,
quantizer=grad._quantizer,
is_2D_scaled=grad._is_2D_scaled,
requires_grad=grad.requires_grad,
)
return dgrad, None
return grad.view(ctx.shape), None
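In short, view/reshape on a blockwise-quantized tensor may merge or split leading dimensions, but the block-scaled trailing dimensions must stay fixed. A small sketch of the allowed and rejected cases, using the same quantizer API as the unit tests above (CUDA device assumed; import path mirrors tensor/utils.py in this diff):

import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer

quantizer = Float8BlockQuantizer(
    fp8_dtype=tex.DType.kFloat8E4M3, rowwise=True, columnwise=True, block_scaling_dim=2
)
x_hp = torch.rand(16, 32, 128, dtype=torch.bfloat16, device="cuda")
x_fp8 = quantizer.make_empty(x_hp.shape, dtype=x_hp.dtype, device="cuda")
quantizer.update_quantized(x_hp.clone(), x_fp8)

y = x_fp8.view(-1, 32, 128)   # OK: the last two (2D-scaled) dimensions are unchanged
# x_fp8.view(16, 128, 32)     # raises RuntimeError: the last two dimensions would change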
......@@ -581,7 +662,6 @@ class _ReshapeFunc(torch.autograd.Function):
# pylint: disable=missing-function-docstring
# Return input tensor if shape is not provided
if ctx is not None:
ctx.shape = tensor.shape
if shape is None:
return tensor
......@@ -598,10 +678,48 @@ class _ReshapeFunc(torch.autograd.Function):
if d == -1:
shape[i] = d_inferred
break
if list(shape) != list(tensor.shape):
raise NotImplementedError("Reshape not implemented yet.")
if tensor._is_2D_scaled:
# For the case of 2D scaled tensor, the last 2 dimensions should not change
if shape[-1] != ctx.shape[-1] or shape[-2] != ctx.shape[-2]:
raise RuntimeError(
"2D scaled Float8BlockwiseQTensor does not support reshaping "
"the last 2 dimensions "
f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)})"
)
else:
# For the case of 1D scaled tensor, the last dimension should not change
if shape[-1] != ctx.shape[-1]:
raise RuntimeError(
"1D scaled Float8BlockwiseQTensor does not support reshaping "
"the last dimension "
f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)})"
)
if list(shape) == list(tensor.shape):
return tensor
# Construct new tensor if shape is provided
new_rowwise_data = None
new_columnwise_data = None
if tensor._rowwise_data is not None:
new_rowwise_data = tensor._rowwise_data.reshape(*shape)
if tensor._columnwise_data is not None:
columnwise_shape = [shape[-1]] + list(shape[:-1])
new_columnwise_data = tensor._columnwise_data.view(columnwise_shape)
return Float8BlockwiseQTensor(
shape=shape,
dtype=tensor.dtype,
fp8_dtype=tensor._fp8_dtype,
rowwise_data=new_rowwise_data,
rowwise_scale_inv=tensor._rowwise_scale_inv,
columnwise_data=new_columnwise_data,
columnwise_scale_inv=tensor._columnwise_scale_inv,
quantizer=tensor._quantizer,
is_2D_scaled=tensor._is_2D_scaled,
requires_grad=tensor.requires_grad,
)
@staticmethod
def backward(
ctx,
......@@ -610,5 +728,24 @@ class _ReshapeFunc(torch.autograd.Function):
# pylint: disable=missing-function-docstring
if isinstance(grad, Float8BlockwiseQTensor):
raise NotImplementedError("Reshape bwd not implemented yet.")
new_rowwise_data = None
new_columnwise_data = None
if grad._rowwise_data is not None:
new_rowwise_data = grad._rowwise_data.view(*ctx.shape)
if grad._columnwise_data is not None:
columnwise_shape = [ctx.shape[-1]] + list(ctx.shape[:-1])
new_columnwise_data = grad._columnwise_data.view(columnwise_shape)
dgrad = Float8BlockwiseQTensor(
shape=ctx.shape,
dtype=grad.dtype,
rowwise_data=new_rowwise_data,
rowwise_scale_inv=grad._rowwise_scale_inv,
columnwise_data=new_columnwise_data,
columnwise_scale_inv=grad._columnwise_scale_inv,
fp8_dtype=grad._fp8_dtype,
quantizer=grad._quantizer,
is_2D_scaled=grad._is_2D_scaled,
requires_grad=grad.requires_grad,
)
return dgrad, None
return grad.view(ctx.shape), None
......@@ -12,6 +12,7 @@ from transformer_engine_torch import multi_tensor_scale, multi_tensor_compute_sc
from .quantized_tensor import QuantizedTensor
from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer
from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer
from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer
from ..optimizers.multi_tensor_apply import multi_tensor_applier
......@@ -32,6 +33,12 @@ def replace_raw_data(tensor: QuantizedTensor, new_raw_data: torch.Tensor):
new_raw_data.detach().copy_(old_raw_data)
tensor._data = new_raw_data
del old_raw_data
elif isinstance(tensor, Float8BlockwiseQTensor):
old_raw_data = tensor._rowwise_data
assert old_raw_data.dtype == new_raw_data.dtype, "The data types of raw data don't match"
new_raw_data.detach().copy_(old_raw_data)
tensor._rowwise_data = new_raw_data
del old_raw_data
elif isinstance(tensor, MXFP8Tensor):
raise NotImplementedError("replace_raw_data for MXFP8Tensor is not supported yet")
else:
......@@ -65,6 +72,7 @@ def cast_master_weights_to_fp8(
delayed_scaling_params = []
current_scaling_params = []
blockwise_scaling_params = []
if fsdp_shard_model_weights is None:
use_fsdp_shard_model_weights = False
......@@ -106,6 +114,10 @@ def cast_master_weights_to_fp8(
current_scaling_params.append(
(model_weight, master_weight, start_offset, fsdp_shard_model_weight)
)
elif isinstance(quantizer, Float8BlockQuantizer):
blockwise_scaling_params.append(
(model_weight, master_weight, start_offset, fsdp_shard_model_weight)
)
elif isinstance(quantizer, MXFP8Quantizer):
raise NotImplementedError(
"cast_master_weights_to_fp8 for MXFP8BlockScaling is not supported yet"
......@@ -123,6 +135,10 @@ def cast_master_weights_to_fp8(
_cast_master_weights_to_fp8_current_scaling(
current_scaling_params, group, use_fsdp_shard_model_weights
)
if len(blockwise_scaling_params) > 0:
_cast_master_weights_to_fp8_blockwise_scaling(
blockwise_scaling_params, group, use_fsdp_shard_model_weights
)
def _cast_master_weights_to_fp8_delayed_scaling(params, group, use_fsdp_shard_model_weights=False):
......@@ -313,3 +329,125 @@ def _cast_master_weights_to_fp8_current_scaling(params, group, use_fsdp_shard_mo
model_weight.dtype,
)
quantizer.update_quantized(master_weight, model_weight_fragment)
def _cast_master_weights_to_fp8_blockwise_scaling(
params, group, use_fsdp_shard_model_weights=False
):
r"""Helper function to cast master weights to FP8 primary weights for blockwise scaling.
Parameters
----------
params : List of tuples; each tuple contains a model weight, a master weight, an offset
indicating the starting index (in elements) of the master weight in the model weight,
and the FSDP-sharded model weight (used only when use_fsdp_shard_model_weights is True).
group : The distributed group to do amax reduction. Typically it's the data parallel
group.
use_fsdp_shard_model_weights : bool, if True, it means that the model weights are sharded.
"""
# Parameter attributes
device = params[0][0].device
block_len = params[0][0]._get_quantizer().block_len
fp8_dtype = params[0][0]._get_quantizer().dtype
force_pow_2_scales = params[0][0]._get_quantizer().force_pow_2_scales
amax_epsilon = params[0][0]._get_quantizer().amax_epsilon
# Create a dummy overflow buffer, it's needed by multi_tensor_applier.
dummy_overflow_buf = torch.zeros(1, dtype=torch.int, device=device)
# Get the total number of amax elements in all the model weights.
cu_amax_sizes = [0]
for model_weight, _, _, _ in params:
scale_shape = model_weight._get_quantizer().get_scale_shape(model_weight.shape, False)
num_amaxes = scale_shape[0] * scale_shape[1]
cu_amax_sizes.append(cu_amax_sizes[-1] + num_amaxes)
# Create a contiguous buffer to store amaxes temporarily, so they can all be reduced
# with a single NCCL all-reduce.
packed_amaxes = torch.zeros(cu_amax_sizes[-1], dtype=torch.float32, device=device)
# ---------------------------------------------------------------------------------------------
# Step 1: Iterate through all the non-empty master weights and compute their amaxes. Store the
# amaxes in a contiguous buffer. If a block of a master weight is empty, the
# corresponding amax will be set to 0.
# ---------------------------------------------------------------------------------------------
amaxes, scales, scale_invs = [], [], []
for i, (model_weight, master_weight, start_offset, _) in enumerate(params):
# Make sure all the model weights have the same numerical options.
quantizer = model_weight._get_quantizer()
assert block_len == quantizer.block_len
assert fp8_dtype == quantizer.dtype
assert force_pow_2_scales == quantizer.force_pow_2_scales
assert amax_epsilon == quantizer.amax_epsilon
scale_shape = quantizer.get_scale_shape(model_weight.shape, False)
amax = packed_amaxes[cu_amax_sizes[i] : cu_amax_sizes[i + 1]].reshape(scale_shape)
scale = torch.empty(scale_shape, dtype=torch.float32, device=device)
scale_inv = model_weight._rowwise_scale_inv
assert len(scale_shape) == 2
assert len(scale_inv.shape) == 2
assert scale_inv.shape[0] == scale_shape[0]
assert scale_inv.shape[1] == scale_shape[1]
amaxes.append(amax)
scales.append(scale)
scale_invs.append(scale_inv)
# Compute amax of the master weight and store it in packed_amaxes.
if master_weight is not None:
assert len(model_weight.shape) == 2
h, w = model_weight.shape
tex.fp8_block_scaling_compute_partial_amax(
master_weight, amax, h, w, start_offset, block_len
)
# ---------------------------------------------------------------------------------------------
# Step 2: Perform all-reduce on packed_amaxes to get the global amax.
# ---------------------------------------------------------------------------------------------
torch.distributed.all_reduce(packed_amaxes, op=torch.distributed.ReduceOp.MAX, group=group)
# ---------------------------------------------------------------------------------------------
# Step 3: Update scales and scale_invs.
# ---------------------------------------------------------------------------------------------
if fp8_dtype == tex.DType.kFloat8E4M3:
max_fp8 = 448.0
elif fp8_dtype == tex.DType.kFloat8E5M2:
max_fp8 = 57344.0
else:
raise ValueError(f"Unsupported FP8 dtype: {fp8_dtype}")
multi_tensor_applier(
multi_tensor_compute_scale_and_scale_inv,
dummy_overflow_buf,
[amaxes, scales, scale_invs],
max_fp8,
force_pow_2_scales,
amax_epsilon,
)
# ---------------------------------------------------------------------------------------------
# Step 4: Cast master weights to FP8.
# ---------------------------------------------------------------------------------------------
for (model_weight, master_weight, start_offset, model_weight_fragment), scale in zip(
params, scales
):
# Clear columnwise data for all model weights.
# We cannot create columnwise data here because users (like Megatron) may want to overlap
# the all-gather of model weights with the forward pass, so the model weight has not been
# fully updated at this point.
model_weight.update_usage(rowwise_usage=True, columnwise_usage=False)
# If master weight is None, it means that the master weight of the current model weight
# is in other DP ranks.
if master_weight is None:
continue
# Cast master weight to FP8
end_offset = start_offset + master_weight.numel()
if not use_fsdp_shard_model_weights:
model_weight_fragment = model_weight._rowwise_data.reshape(-1)[start_offset:end_offset]
assert len(model_weight.shape) == 2
h, w = model_weight.shape
tex.fp8_block_scaling_partial_cast(
master_weight, model_weight_fragment, scale, h, w, start_offset, block_len, fp8_dtype
)
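For intuition about Step 3, the per-block math that multi_tensor_compute_scale_and_scale_inv fuses is roughly the following; this is a simplified sketch, and the fused kernel's exact epsilon and power-of-two handling may differ:

import math

def block_scale_and_scale_inv(amax, max_fp8=448.0, force_pow_2=True, eps=0.0):
    # Map one block's (all-reduced) amax to (scale, scale_inv); amax == 0 keeps scale = 1.
    amax = max(amax, eps)
    scale = max_fp8 / amax if amax > 0.0 else 1.0
    if force_pow_2:
        # Power-of-two scales only shift the exponent, which helps numerical consistency.
        scale = 2.0 ** math.floor(math.log2(scale))
    return scale, 1.0 / scale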