Unverified Commit 96ee7173 authored by Zhongbo Zhu, committed by GitHub

[PyTorch][MoE] MXFP8 Support to Reduce CPU Overhead by Fusing Torch Empty Calls (#1934)



* functional tests passed
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* until zero padding is added to the MXFP8 swizzle kernel, allocate with torch.zeros for now
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* format
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* lint
Signed-off-by: zhongboz <zhongboz@nvidia.com>

---------
Signed-off-by: zhongboz <zhongboz@nvidia.com>
parent 3c4dfffb
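
The core idea of the change, sketched below in plain PyTorch for illustration (a simplified sketch, not the PR's actual code: the real implementation lives in C++, aligns offsets, and also packs the MXFP8 scale tensors): rather than issuing one torch.empty/torch.zeros call per expert output, allocate a single flat buffer and hand out views into it, so the CPU pays for one allocator round trip instead of N.

import torch

shapes = [(4096, 4096)] * 8  # e.g. one output per expert in a grouped GEMM

# Unfused: one CUDA caching-allocator round trip per tensor (8 calls).
outs_unfused = [torch.empty(s, dtype=torch.uint8, device="cuda") for s in shapes]

# Fused (simplified): a single allocation, then zero-copy views into it.
numels = [s[0] * s[1] for s in shapes]
buffer = torch.zeros(sum(numels), dtype=torch.uint8, device="cuda")
outs_fused, offset = [], 0
for shape, n in zip(shapes, numels):
    outs_fused.append(buffer[offset:offset + n].view(shape))
    offset += n
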
@@ -9,15 +9,45 @@ import pandas as pd
import pathlib
from transformer_engine.pytorch.module import GroupedLinear
from transformer_engine.common.recipe import Float8BlockScaling, MXFP8BlockScaling
from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager
from contextlib import nullcontext
"""
# Profile BF16 recipe with Nsight Systems
nsys profile \
--output=./benchmarks/linear/b200_mkn_4096_4096_4096_numgemm_8_bf16 \
--force-overwrite true \
--trace=cuda,nvtx,cudnn,cublas \
python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe bf16
# Profile FP8 sub-channel recipe with Nsight Systems
nsys profile \
--output=./benchmarks/linear/h100hbm_mkn_4096_4096_4096_numgemm_8_fp8_sub_channel \
--force-overwrite true \
--trace=cuda,nvtx,cudnn,cublas \
python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe fp8_sub_channel
# Profile MXFP8 recipe with Nsight Systems
nsys profile \
--output=./benchmarks/linear/b200_mkn_4096_4096_4096_numgemm_8_mxfp8 \
--force-overwrite true \
--trace=cuda,nvtx,cudnn,cublas \
python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe mxfp8
"""
RECIPES = {
    "bf16": None,
    "fp8_sub_channel": Float8BlockScaling(),
    "mxfp8": MXFP8BlockScaling(),
}
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
    FP8GlobalStateManager.is_fp8_block_scaling_available()
)
def run_linear_multiple_steps(layer, x, m_splits, mode, gradient, run_num_steps=1, recipe=None):
    assert mode in ["fwd_only", "fwd_bwd"]
@@ -187,36 +217,43 @@ if __name__ == "__main__":
default="benchmark_output/", default="benchmark_output/",
help="output path for report", help="output path for report",
) )
    # arguments for recipe, options are fp8_sub_channel, mxfp8, bf16, all
    parser.add_argument(
        "--recipe",
        type=str,
        default="bf16",
        help="Recipe to use, options are fp8_sub_channel, mxfp8, bf16, or all",
    )
    args = parser.parse_args()

    use_bias = False

    # Set the MKN values to benchmark
    mkns = []
    for m in [8192]:
        # for m in [4096, 8192, 16384]:
        # for n in [1024, 2048, 4096, 8192, 16384]:
        for n in [8192]:
            for k in [4096]:
                mkns.append((m, k, n))
    # default recipes to run if not specified
    recipe_list = ["bf16"]
    if args.recipe == "all":
        recipe_list = ["bf16", "fp8_sub_channel", "mxfp8"]
    else:
        recipe_list = [args.recipe]

    num_gemms_list = [8]
    if args.profile:
        mkns = [(4096, 4096, 4096)]
        # in profile mode, only run one recipe specified in args.recipe
        assert args.recipe != "all", (
            "In profile mode, only one recipe can be specified, please specify the recipe as"
            " fp8_sub_channel, mxfp8, or bf16"
        )
        recipe_list = [args.recipe]
        num_gemms_list = [8]
        torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__()
@@ -227,6 +264,18 @@ if __name__ == "__main__":
    for num_gemms in num_gemms_list:
        print(f"========== Benchmarking with num_gemms={num_gemms} ==========")
        for recipe_name in recipe_list:
            assert recipe_name in [
                "bf16",
                "fp8_sub_channel",
                "mxfp8",
            ], "Recipe must be one of bf16, fp8_sub_channel, or mxfp8"
            if recipe_name == "mxfp8" and not mxfp8_available:
                print(f"MXFP8 is not available, skipping {recipe_name}")
                continue
            if recipe_name == "fp8_sub_channel" and not fp8_block_scaling_available:
                print(f"FP8 block scaling is not available, skipping {recipe_name}")
                continue
            df = run_benchmark_linear(
                mkns,
                recipe_name,
...
@@ -214,6 +214,8 @@ class MXFP8Quantizer : public Quantizer {
  std::pair<TensorWrapper, py::object> create_tensor(
      const std::vector<size_t>& shape, DType dtype,
      std::optional<at::Tensor> rowwise_data = std::nullopt) const override;

  std::vector<size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const;
};

std::unique_ptr<Quantizer> convert_quantizer(py::handle quantizer);
...
@@ -387,6 +387,162 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_fp
  return retval;
}
std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mxfp8_tensors(
    std::vector<std::vector<size_t>> &shape_list, std::vector<py::handle> &quantizer_py_list,
    std::vector<MXFP8Quantizer *> &quantizer_cpp_list) {
  init_extension();
  std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> retval;
  auto &tensor_py_list = std::get<0>(retval);
  auto &tensor_cpp_list = std::get<1>(retval);

  // Number of tensors
  const size_t num_tensors = shape_list.size();
  if (num_tensors == 0) {
    return retval;
  }

  // Quantization parameters
  const auto rowwise_usage = quantizer_cpp_list[0]->rowwise_usage;
  const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage;
  const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode();
  const auto fp8_dtype = quantizer_cpp_list[0]->dtype;
  constexpr size_t fp8_elem_size = 1;
  constexpr size_t scale_elem_size = 1;

  // Helper function to construct tensor view
  // Note: Deleter holds a shared_ptr for the buffer, so the buffer
  // will survive until all views are deleted.
  auto make_torch_view = [](std::shared_ptr<at::Tensor> &buffer, const std::vector<size_t> &shape,
                            size_t offset, at::ScalarType dtype) -> at::Tensor {
    std::vector<int64_t> shape_int64(shape.begin(), shape.end());
    // If the full buffer is empty because the local rank received no tokens for any expert,
    // then data_ptr is nullptr and we must return an empty tensor instead of calling from_blob.
    // When some experts receive tokens and others do not, we still want to use from_blob
    // as much as possible to avoid CPU overhead.
    if (buffer->data_ptr<uint8_t>() == nullptr) {
      return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype));
    }
    return at::from_blob(
        buffer->data_ptr<uint8_t>() + offset, shape_int64,
        [buffer](void *) {},  // deleter holds shared_ptr
        at::device(at::kCUDA).dtype(dtype));
  };

  // Allocate row-wise data
  std::vector<at::Tensor> rowwise_data_list, rowwise_scale_list;
  std::vector<std::vector<size_t>> rowwise_data_shapes, rowwise_scale_shapes;
  if (rowwise_usage) {
    // Tensor sizes
    for (size_t i = 0; i < num_tensors; ++i) {
      rowwise_data_shapes.emplace_back(shape_list[i]);
      rowwise_scale_shapes.emplace_back(
          quantizer_cpp_list[i]->get_scale_shape(shape_list[i], false));
    }

    // Offsets in full buffer
    size_t buffer_size = 0;
    std::vector<size_t> data_offsets, scale_offsets;
    for (size_t i = 0; i < num_tensors; ++i) {
      buffer_size = roundup(buffer_size, 256);  // align to 256B
      data_offsets.push_back(buffer_size);
      buffer_size += product(rowwise_data_shapes[i]) * fp8_elem_size;
    }
    for (size_t i = 0; i < num_tensors; ++i) {
      buffer_size = roundup(buffer_size, 16);  // align to 16B
      scale_offsets.push_back(buffer_size);
      buffer_size += product(rowwise_scale_shapes[i]) * scale_elem_size;
    }

    // Allocate full buffer
    // TODO(zhongbo): use torch.empty if zero padding is added to the swizzle kernel
    auto buffer = std::make_shared<at::Tensor>(
        at::zeros({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
    // auto buffer = std::make_shared<at::Tensor>(
    //     at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));

    // Construct tensor views
    for (size_t i = 0; i < num_tensors; ++i) {
      rowwise_data_list.emplace_back(
          make_torch_view(buffer, rowwise_data_shapes[i], data_offsets[i], torch::kUInt8));
      rowwise_scale_list.emplace_back(
          make_torch_view(buffer, rowwise_scale_shapes[i], scale_offsets[i], torch::kUInt8));
    }
  }

  // Allocate column-wise data
  std::vector<at::Tensor> columnwise_data_list, columnwise_scale_list;
  std::vector<std::vector<size_t>> columnwise_data_shapes, columnwise_scale_shapes;
  if (columnwise_usage) {
    // Tensor sizes
    for (size_t i = 0; i < num_tensors; ++i) {
      // For MXFP8, the columnwise data doesn't need transpose
      // because of TN, NT, NN layout support in SM100
      columnwise_data_shapes.emplace_back(shape_list[i]);
      columnwise_scale_shapes.emplace_back(
          quantizer_cpp_list[i]->get_scale_shape(shape_list[i], true));
    }

    // Offsets in full buffer
    size_t buffer_size = 0;
    std::vector<size_t> data_offsets, scale_offsets;
    for (size_t i = 0; i < num_tensors; ++i) {
      buffer_size = roundup(buffer_size, 256);  // align to 256B
      data_offsets.push_back(buffer_size);
      buffer_size += product(columnwise_data_shapes[i]) * fp8_elem_size;
    }
    for (size_t i = 0; i < num_tensors; ++i) {
      buffer_size = roundup(buffer_size, 16);  // align to 16B
      scale_offsets.push_back(buffer_size);
      buffer_size += product(columnwise_scale_shapes[i]) * scale_elem_size;
    }

    // Allocate full buffer
    // TODO(zhongbo): use torch.empty if zero padding is added to the swizzle kernel
    auto buffer = std::make_shared<at::Tensor>(
        at::zeros({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
    // auto buffer = std::make_shared<at::Tensor>(
    //     at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));

    // Construct tensor views
    for (size_t i = 0; i < num_tensors; ++i) {
      columnwise_data_list.emplace_back(
          make_torch_view(buffer, columnwise_data_shapes[i], data_offsets[i], torch::kUInt8));
      columnwise_scale_list.emplace_back(
          make_torch_view(buffer, columnwise_scale_shapes[i], scale_offsets[i], torch::kUInt8));
    }
  }

  // Construct mxfp8 tensors
  py::handle MXFP8TensorClass(reinterpret_cast<PyObject *>(MXFP8TensorBasePythonClass));
  for (size_t i = 0; i < num_tensors; ++i) {
    // Create tensor objects with proper reference counting
    py::object rowwise_data = rowwise_usage ? py::cast(rowwise_data_list[i]) : py::none();
    py::object rowwise_scale = rowwise_usage ? py::cast(rowwise_scale_list[i]) : py::none();
    py::object columnwise_data =
        (columnwise_usage ? py::cast(columnwise_data_list[i]) : py::none());
    py::object columnwise_scale =
        (columnwise_usage ? py::cast(columnwise_scale_list[i]) : py::none());

    // Construct Python tensor
    tensor_py_list.emplace_back(MXFP8TensorClass(rowwise_data, rowwise_scale, columnwise_data,
                                                 columnwise_scale, fp8_dtype,
                                                 quantizer_py_list[i]));

    // Construct C++ tensor
    tensor_cpp_list.emplace_back(makeTransformerEngineTensor(
        rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr,
        columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr,
        rowwise_usage ? rowwise_data_shapes[i] : std::vector<size_t>{},
        columnwise_usage ? columnwise_data_shapes[i] : std::vector<size_t>{}, fp8_dtype, nullptr,
        nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr,
        columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr,
        rowwise_usage ? rowwise_scale_shapes[i] : std::vector<size_t>{},
        columnwise_usage ? columnwise_scale_shapes[i] : std::vector<size_t>{}, scaling_mode));
  }

  return retval;
}
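
For reference, the offset bookkeeping above can be condensed into a few lines of Python (a sketch mirroring the C++; roundup and plan_layout are hypothetical names): FP8 data blocks are aligned to 256 bytes and E8M0 scale blocks to 16 bytes, and both element types are one byte wide, so element counts double as byte counts.

def roundup(x: int, align: int) -> int:
    # Same rounding helper the C++ code uses for buffer alignment.
    return (x + align - 1) // align * align

def plan_layout(data_numels, scale_numels):
    """Compute total buffer size and per-tensor byte offsets (sketch)."""
    size, data_offsets, scale_offsets = 0, [], []
    for n in data_numels:
        size = roundup(size, 256)  # FP8 data: 256B-aligned
        data_offsets.append(size)
        size += n                  # fp8_elem_size == 1 byte
    for n in scale_numels:
        size = roundup(size, 16)   # E8M0 scales: 16B-aligned
        scale_offsets.append(size)
        size += n                  # scale_elem_size == 1 byte
    return size, data_offsets, scale_offsets
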
}  // namespace
std::vector<py::object> split_quantize(const at::Tensor &tensor,
@@ -441,9 +597,11 @@ std::vector<py::object> split_quantize(const at::Tensor &tensor,
  }

  // For FP8 block-scaling, we construct output tensors with bulk allocations
  // For MXFP8, we also use bulk allocations
  bool use_fused_bulk_alloc = true;
  for (size_t i = 0; i < quantizer_list.size(); i++) {
    if (!detail::IsFloat8BlockwiseQuantizers(quantizer_list[i].ptr()) &&
        !detail::IsMXFP8Quantizers(quantizer_list[i].ptr())) {
      use_fused_bulk_alloc = false;
      break;
    }
@@ -461,13 +619,28 @@ std::vector<py::object> split_quantize(const at::Tensor &tensor,
      output_py_list.emplace_back(std::move(output_py));
    }
  } else {
    // TODO(zhongbo): make a better api to make this part less hacky
    bool is_fp8_blockwise = detail::IsFloat8BlockwiseQuantizers(quantizer_list[0].ptr());
    bool is_mxfp8 = detail::IsMXFP8Quantizers(quantizer_list[0].ptr());
    if (is_fp8_blockwise) {
      // FP8 block-scaling: construct output tensors with bulk allocations
      std::vector<Float8BlockQuantizer *> blockwise_quantizers;
      for (auto &quantizer : quantizer_cpp_list) {
        blockwise_quantizers.push_back(static_cast<Float8BlockQuantizer *>(quantizer.get()));
      }
      std::tie(output_py_list, output_cpp_list) =
          bulk_allocate_fp8_blockwise_tensors(split_shapes, quantizer_list, blockwise_quantizers);
    } else if (is_mxfp8) {
      // MXFP8: construct output tensors with bulk allocations
      std::vector<MXFP8Quantizer *> mxfp8_quantizers;
      for (auto &quantizer : quantizer_cpp_list) {
        mxfp8_quantizers.push_back(static_cast<MXFP8Quantizer *>(quantizer.get()));
      }
      std::tie(output_py_list, output_cpp_list) =
          bulk_allocate_mxfp8_tensors(split_shapes, quantizer_list, mxfp8_quantizers);
    } else {
      NVTE_CHECK(false, "Expected either FP8 block-scaling or MXFP8 quantizer");
    }
  }

  // Perform multi-tensor quantization
...
@@ -480,11 +480,6 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(
  at::Tensor rowwise_data1, columnwise_data, rowwise_scale_inv,
      columnwise_scale_inv;  // TODO(pgadzinski) - change
  opts = opts.dtype(torch::kUInt8).device(torch::kCUDA);

  at::Tensor data;
  if (rowwise_usage) {
@@ -493,9 +488,10 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(
    } else {
      data = at::empty(torch_shape, opts);
    }
    auto scale_shape = get_scale_shape(shape, false);
    size_t sinv0 = scale_shape[0];
    size_t sinv1 = scale_shape[1];
    rowwise_scale_inv = at::zeros({static_cast<int64_t>(sinv0), static_cast<int64_t>(sinv1)}, opts);
    tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape);
    tensor.set_rowwise_scale_inv(
        rowwise_scale_inv.data_ptr(), DType::kFloat8E8M0,
@@ -503,10 +499,12 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(
  }

  if (columnwise_usage) {
    auto scale_shape = get_scale_shape(shape, true);
    size_t sinv0 = scale_shape[0];
    size_t sinv1 = scale_shape[1];
    columnwise_data = at::empty(torch_shape, opts);
    columnwise_scale_inv =
        at::zeros({static_cast<int64_t>(sinv0), static_cast<int64_t>(sinv1)}, opts);
    tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, shape);
    tensor.set_columnwise_scale_inv(
@@ -534,4 +532,35 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(
  return {std::move(tensor), std::move(ret)};
}
std::vector<size_t> MXFP8Quantizer::get_scale_shape(const std::vector<size_t>& shape,
                                                    bool columnwise) const {
  size_t numel = 1;
  for (auto s : shape) {
    numel *= s;
  }
  auto last_dim = shape.back();

  NVTE_CHECK(last_dim % MXFP8_BLOCK_SIZE == 0 && (numel / last_dim) % MXFP8_BLOCK_SIZE == 0,
             "MXFP8 requires tensor dims that are divisible by ", MXFP8_BLOCK_SIZE,
             " (got shape=", shape, ")");

  std::vector<size_t> scale_shape;

  bool rowwise_usage = !columnwise;

  if (rowwise_usage) {
    // rowwise scaling factor shape
    size_t sinv0 = roundup(numel / last_dim, 128);
    size_t sinv1 = roundup(last_dim / MXFP8_BLOCK_SIZE, 4);
    scale_shape = {sinv0, sinv1};
  } else {
    // columnwise scaling factor shape
    size_t sinv0 = roundup(numel / (last_dim * MXFP8_BLOCK_SIZE), 4);
    size_t sinv1 = roundup(last_dim, 128);
    scale_shape = {sinv0, sinv1};
  }

  return scale_shape;
}
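
To make the padding concrete, here is a Python rendering of get_scale_shape with a worked example (a sketch; MXFP8_BLOCK_SIZE is 32, i.e. one E8M0 scale per 32 elements):

MXFP8_BLOCK_SIZE = 32  # one E8M0 scale per 32 elements

def roundup(x, align):
    return (x + align - 1) // align * align

def mxfp8_scale_shape(shape, columnwise):
    numel = 1
    for s in shape:
        numel *= s
    rows, cols = numel // shape[-1], shape[-1]
    if not columnwise:
        # scales along the last dim, padded for the swizzle kernel
        return (roundup(rows, 128), roundup(cols // MXFP8_BLOCK_SIZE, 4))
    # scales along the leading dims
    return (roundup(rows // MXFP8_BLOCK_SIZE, 4), roundup(cols, 128))

# e.g. for a (4096, 4096) tensor:
#   mxfp8_scale_shape((4096, 4096), columnwise=False) == (4096, 128)
#   mxfp8_scale_shape((4096, 4096), columnwise=True)  == (128, 4096)
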
}  // namespace transformer_engine::pytorch
@@ -68,13 +68,13 @@ class MXFP8TensorBase(QuantizedTensorBase):
    def __new__(
        cls,
        rowwise_data: Optional[torch.Tensor],
        rowwise_scale_inv: Optional[torch.Tensor],
        columnwise_data: Optional[torch.Tensor],
        columnwise_scale_inv: Optional[torch.Tensor],
        fp8_dtype: TE_DType,
        quantizer: Optional[Quantizer],
        *args,
        **kwargs,
    ):
        if cls is MXFP8TensorBase:
...
@@ -165,6 +165,32 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
    """
    # NOTE: We reorder the *args so that we can instantiate an MXFP8TensorBase with positional
    # args, which significantly reduces the pybind11 overhead when calling the constructor
    # from C++.
    def __new__(
        cls,
        *args,
        rowwise_data: Optional[torch.Tensor],
        rowwise_scale_inv: Optional[torch.Tensor],
        columnwise_data: Optional[torch.Tensor],
        columnwise_scale_inv: Optional[torch.Tensor],
        fp8_dtype: TE_DType,
        quantizer: Optional[Quantizer],
        **kwargs,
    ):
        instance = super().__new__(
            cls,
            rowwise_data,
            rowwise_scale_inv,
            columnwise_data,
            columnwise_scale_inv,
            fp8_dtype,
            quantizer,
            *args,
            **kwargs,
        )
        return instance
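
A sketch of the two call paths this reordering enables (argument values assumed to already exist; the positional C++ call is the one made in bulk_allocate_mxfp8_tensors above):

# From C++ (via pybind11), the base class is called positionally, avoiding
# per-call keyword-dict construction:
#   MXFP8TensorClass(rowwise_data, rowwise_scale, columnwise_data,
#                    columnwise_scale, fp8_dtype, quantizer)
#
# From Python, construction stays keyword-only (sketch; any remaining
# positional *args pass through to the parent class):
tensor = MXFP8Tensor(
    rowwise_data=rowwise_data,
    rowwise_scale_inv=rowwise_scale_inv,
    columnwise_data=None,
    columnwise_scale_inv=None,
    fp8_dtype=fp8_dtype,
    quantizer=quantizer,
)
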
    def __repr__(self, *, tensor_contents=None):
        return f"MXFP8Tensor(fp8_dtype={self._fp8_dtype}, data={self.dequantize(dtype=self.dtype)})"
@@ -302,6 +328,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
        fp8_dtype: TE_DType,
        dtype: torch.dtype,
        shape: torch.shape,
        quantizer: Quantizer,
    ) -> MXFP8Tensor:
        """Build MXFP8Tensor, for use in __reduce__
@@ -317,6 +344,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
            columnwise_scale_inv=columnwise_scale_inv,
            dtype=dtype,
            shape=shape,
            quantizer=quantizer,
        )
    def __reduce_ex__(self, protocol: int) -> tuple:
@@ -331,6 +359,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
                self._fp8_dtype,
                self.dtype,
                self.shape,
                self._quantizer,
            ),
        )
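
With _quantizer added to the reduce tuple, pickling now round-trips the quantizer as well; a quick check (sketch, assuming t is an existing MXFP8Tensor):

import pickle

restored = pickle.loads(pickle.dumps(t))
assert restored._quantizer is not None  # preserved by the new reduce tuple
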
...