Unverified commit de69ca0e authored by hx, committed by GitHub
Browse files

[PyTorch] fix input_quantizer usage for save_original_input; fix blockwise FP8...


[PyTorch] fix input_quantizer usage for save_original_input; fix blockwise FP8 convert_and_update_tensor (#1978)

* fix input_quantizer in save_original_input bwd
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

* fix get shape of blockwise tensor with only compact colwise data
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

* fix blockwise FP8 convert_and_update_tensor
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent c0d2f1a5
......@@ -219,7 +219,7 @@ class TestFloat8BlockwiseTensor:
rowwise=True,
columnwise=dq_columnwise,
block_scaling_dim=block_scaling_dim,
all_gather_usage=True,
all_gather_usage=(block_scaling_dim == 1),
)
self._test_quantize_dequantize(
quantizer=quantizer,
......
......@@ -671,13 +671,128 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_te
const DType dtype = tensor.attr("_fp8_dtype").cast<DType>();
bool is_2D_scaled = tensor.attr("_is_2D_scaled").cast<bool>();
// Check the data matches quantizer usages
NVTE_CHECK(!tensor.attr("_rowwise_data").is_none() == rowwise_usage,
"Float8BlockwiseQTensor does not match quantizer usages (has_rowwise_data=",
!tensor.attr("_rowwise_data").is_none(), ", rowwise_usage=", rowwise_usage);
NVTE_CHECK(!tensor.attr("_columnwise_data").is_none() == columnwise_usage,
"Float8BlockwiseQTensor does not match quantizer usages (has_columnwise_data=",
!tensor.attr("_columnwise_data").is_none(), ", columnwise_usage=", columnwise_usage);
// Extract buffers from Python tensor
// Read attribute `name` from the Python tensor object. The result is empty
// when the attribute is None; otherwise it holds the attribute cast to
// at::Tensor.
auto get_tensor = [&tensor](const char* name) -> std::optional<at::Tensor> {
  std::optional<at::Tensor> buffer;
  auto py_attr = tensor.attr(name);
  if (!py_attr.is_none()) {
    buffer = py_attr.cast<at::Tensor>();
  }
  return buffer;
};
auto rowwise_data = get_tensor("_rowwise_data");
auto rowwise_scale_inv = get_tensor("_rowwise_scale_inv");
auto columnwise_data = get_tensor("_columnwise_data");
auto columnwise_scale_inv = get_tensor("_columnwise_scale_inv");
NVTE_CHECK(rowwise_data || columnwise_data, "FP8BlockwiseTensor has no data.");
// Tensor options and dimensions
at::TensorOptions opts;
at::TensorOptions scale_opts;
opts = opts.dtype(torch::kUInt8).device(torch::kCUDA);
scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA);
// Derive a tensor shape from the column-wise buffer.
//  - No column-wise buffer: returns an empty shape.
//  - all_gather_usage == true: returns the buffer's shape unchanged.
//  - Otherwise: returns the buffer's shape with its leading dimension moved
//    to the end (remaining dimensions keep their relative order).
auto get_columnwise_shape = [&columnwise_data](bool all_gather_usage) -> std::vector<size_t> {
  if (!columnwise_data) {
    return {};
  }
  std::vector<size_t> stored_dims = getTensorShape(*columnwise_data);
  if (all_gather_usage || stored_dims.empty()) {
    // Nothing to permute: either the layout is used as-is or there are no dims.
    return stored_dims;
  }
  // Rotate left by one: [d0, d1, ..., dk] -> [d1, ..., dk, d0].
  std::vector<size_t> rotated_dims(stored_dims.begin() + 1, stored_dims.end());
  rotated_dims.push_back(stored_dims.front());
  return rotated_dims;
};
std::vector<size_t> shape;
if (rowwise_data) {
shape = getTensorShape(*rowwise_data);
if (columnwise_data) {
auto expected_shape = get_columnwise_shape(all_gather_usage);
NVTE_CHECK(shape == expected_shape, "BlockwiseFP8 row-wise data (shape=", shape,
") and column-wise data (shape=", expected_shape, ") do not match");
}
} else {
shape = get_columnwise_shape(all_gather_usage);
}
std::vector<int64_t> torch_shape;
for (auto s : shape) {
torch_shape.emplace_back(static_cast<int64_t>(s));
}
// Coerce row-wise data
if (rowwise_usage) {
if (!rowwise_data) {
rowwise_data = at::empty(torch_shape, opts);
tensor.attr("_rowwise_data") = *rowwise_data;
}
if (!rowwise_scale_inv) {
auto scale_shape = get_scale_shape(shape, false);
size_t sinv0 = scale_shape[0];
size_t sinv1 = scale_shape[1];
rowwise_scale_inv =
at::empty({static_cast<int64_t>(sinv0), static_cast<int64_t>(sinv1)}, scale_opts);
tensor.attr("_rowwise_scale_inv") = *rowwise_scale_inv;
}
} else { // rowwise_usage == false
if (rowwise_data) {
rowwise_data.reset();
tensor.attr("_rowwise_data") = py::none();
}
if (rowwise_scale_inv) {
rowwise_scale_inv.reset();
tensor.attr("_rowwise_scale_inv") = py::none();
}
}
// Coerce column-wise data
if (columnwise_usage) {
std::vector<size_t> columnwise_shape;
std::vector<int64_t> torch_columnwise_shape;
if (torch_shape.size() > 0) {
if (!all_gather_usage) {
torch_columnwise_shape.reserve(torch_shape.size());
columnwise_shape.reserve(shape.size());
torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]);
columnwise_shape.push_back(shape[shape.size() - 1]);
for (size_t i = 0; i < torch_shape.size() - 1; ++i) {
torch_columnwise_shape.push_back(torch_shape[i]);
columnwise_shape.push_back(shape[i]);
}
} else {
// assert we are doing 1D scaling
NVTE_CHECK(block_scaling_dim == 1,
"Compact columnwise format is not supported for 128x128 2D block scaling.");
torch_columnwise_shape = torch_shape;
columnwise_shape = shape;
}
}
if (!columnwise_data) {
columnwise_data = at::empty(torch_columnwise_shape, opts);
tensor.attr("_columnwise_data") = *columnwise_data;
}
if (!columnwise_scale_inv) {
auto scale_shape = get_scale_shape(shape, true);
size_t sinv0 = scale_shape[0];
size_t sinv1 = scale_shape[1];
columnwise_scale_inv =
at::empty({static_cast<int64_t>(sinv0), static_cast<int64_t>(sinv1)}, scale_opts);
tensor.attr("_columnwise_scale_inv") = *columnwise_scale_inv;
}
} else { // columnwise_usage == false
if (columnwise_data) {
columnwise_data.reset();
tensor.attr("_columnwise_data") = py::none();
}
if (columnwise_scale_inv) {
columnwise_scale_inv.reset();
tensor.attr("_columnwise_scale_inv") = py::none();
}
}
auto ret = TensorWrapper(is_2D_scaled ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D);
......
......@@ -589,13 +589,14 @@ class _Linear(torch.autograd.Function):
else:
# Quantize input tensor
quantizer = ctx.input_quantizer
if ctx.backward_input_needs_gather and isinstance(
quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
):
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
# All-gather is not supported with FP8 column-wise data
quantizer.set_usage(rowwise=True, columnwise=False)
quantizer.set_usage(
rowwise=True,
columnwise=not ctx.backward_input_needs_gather,
)
else:
quantizer.set_usage(rowwise=True, columnwise=True)
quantizer.set_usage(rowwise=False, columnwise=True)
inputmat = quantizer(inputmat)
else:
if isinstance(inputmat, QuantizedTensorBase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment