Unverified commit c30e961f authored by Zhongbo Zhu, committed by GitHub

[PyTorch][MoE] Reduce CPU Overhead by Fusing torch.empty Calls (#1793)



* finish python ref impl for bulk alloc
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* c++ bulk alloc worked, still draft version
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* clean up
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* resolve rebase conflict
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add license
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* use shared_ptr to auto manage reference count
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* attempt to fix misc training error
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* attempt to handle case where experts get zero token
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* updated with fused C++ function calls
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* clean up
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* experiment with reducing py object construction time
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* fix seg fault bug in inference mode
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* fix lint
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* fuse torch split into bulk alloc
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* clean up
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* rebase to latest main
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* fix unit test failure
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* fix lint error
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* refactor create_tensor to use get_scale_shape
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* refactor quantize to call quantize_cpp
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* Implement separate functions for multi-tensor quantize and split + multi-tensor quantize
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update grouped linear module with fused split+quantize func
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Move multi-tensor quantize func to cast.cpp
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Do not expose quantizer helper function externally
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter warnings
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Revert cuDNN frontend commit
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* fix corner cases with zero tokens
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* add comments
Signed-off-by: zhongboz <zhongboz@nvidia.com>

---------
Signed-off-by: zhongboz <zhongboz@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
parent 7db72dbc
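At a glance, the change replaces the per-expert torch.split + per-tensor quantize pattern in GroupedLinear with a single fused call that also allocates the quantized per-expert outputs in bulk. A minimal sketch of that call-pattern change (illustration only; the quantizer construction is simplified here, and the real module builds its own quantizers as the diff below shows):

import torch
import transformer_engine_torch as tex

def quantize_expert_inputs_old(inp, m_splits, quantizers):
    # Old pattern: split into per-expert views, then quantize each chunk separately.
    # Every chunk pays for its own Python call and several small torch.empty allocations.
    chunks = torch.split(inp.view(-1, inp.shape[-1]), m_splits)
    return [quantizer(chunk) for quantizer, chunk in zip(quantizers, chunks)]

def quantize_expert_inputs_new(inp, m_splits, quantizers):
    # New pattern: one C++ call splits the flattened input and quantizes every chunk,
    # allocating the per-expert output buffers in bulk.
    return tex.split_quantize(inp.view(-1, inp.shape[-1]), m_splits, quantizers)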
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import argparse
import torch
import torch.utils.benchmark as benchmark
import pandas as pd
import pathlib
from transformer_engine.pytorch.module import GroupedLinear
from transformer_engine.common.recipe import Float8BlockScaling
from transformer_engine.pytorch.fp8 import fp8_autocast
from contextlib import nullcontext

RECIPES = {
    "bf16": None,
    "fp8_sub_channel": Float8BlockScaling(),
}


def run_linear_multiple_steps(layer, x, m_splits, mode, gradient, run_num_steps=1, recipe=None):
    assert mode in ["fwd_only", "fwd_bwd"]
    fp8_context = (
        fp8_autocast(enabled=True, fp8_recipe=recipe) if recipe is not None else nullcontext()
    )

    if mode == "fwd_only":
        with torch.no_grad(), fp8_context:
            for i in range(run_num_steps):
                y_q = layer.forward(
                    x,
                    m_splits,
                    is_first_microbatch=(i == 0),
                )
        return y_q
    else:
        # reset gradients
        layer.zero_grad()
        x.grad = None
        with fp8_context:
            for i in range(run_num_steps):
                label = f"step_{i}"
                torch.cuda.nvtx.range_push(label)
                y_q = layer.forward(
                    x,
                    m_splits,
                    is_first_microbatch=(i == 0),
                )
                y_q.backward(gradient)
                torch.cuda.nvtx.range_pop()
        grads_q = []
        grads_q.append(x.grad)
        # remaining derivatives are with respect to model parameters
        for p in layer.parameters():
            if p.requires_grad:
                grads_q.append(p.grad)
        return y_q, grads_q


def benchmark_linear(
    x,
    ws,
    m_splits,
    bias,
    recipe_name,
    mode,
    num_gemms=4,
):
    params_dtype = torch.bfloat16
    recipe = RECIPES[recipe_name]
    in_features = x.shape[1]
    out_features = ws[0].shape[0]
    gradient = torch.ones((x.shape[0], out_features), dtype=torch.bfloat16, device=x.device)

    layer = GroupedLinear(
        num_gemms,
        in_features,
        out_features,
        bias=bias is not None,
        params_dtype=params_dtype,
    )
    layer = layer.to("cuda")

    with torch.no_grad():
        for i in range(num_gemms):
            weight_i = getattr(layer, f"weight{i}")
            weight_i.copy_(ws[i])
            if bias is not None:
                bias_i = getattr(layer, f"bias{i}")
                bias_i.copy_(bias)

    num_microbatches = 32
    label = f"{recipe_name}_{'grouped'}"
    torch.cuda.nvtx.range_push(label)
    timing = benchmark.Timer(
        stmt=(
            "run_linear_multiple_steps(layer, x, m_splits, mode, gradient, num_microbatches,"
            " recipe)"
        ),
        globals={
            "run_linear_multiple_steps": run_linear_multiple_steps,
            "layer": layer,
            "x": x,
            "m_splits": m_splits,
            "mode": mode,
            "gradient": gradient,
            "num_microbatches": num_microbatches,
            "recipe": recipe,
        },
        num_threads=1,
    ).blocked_autorange(min_run_time=5)
    print(f"{recipe_name}: {timing} \n")

    timing_ms = timing.median * 1000 / num_microbatches
    return timing_ms


def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4):
    data = []
    assert not use_bias, "Bias is not supported for GroupedLinear benchmark"
    print(f"========== Benchmarking {recipe_name} ==========")
    for m, k, n in mkns:
        device = "cuda"
        x = torch.randn((m, k), dtype=torch.bfloat16, device=device, requires_grad=True)
        ws = [torch.randn((n, k), dtype=torch.bfloat16, device=device) for _ in range(num_gemms)]
        assert m % num_gemms == 0
        m_splits = [m // num_gemms] * num_gemms

        # Bias is not supported for GroupedLinear benchmark
        bias = None

        # Run the benchmark
        print(f"fwd_m={m}, fwd_k={k}, fwd_n={n}")
        grouped_fwd_bwd_timing_ms = benchmark_linear(
            x,
            ws,
            m_splits,
            bias,
            recipe_name,
            mode="fwd_bwd",
            num_gemms=num_gemms,
        )

        # Append the results
        data.append(
            [
                m,
                k,
                n,
                recipe_name,
                num_gemms,
                grouped_fwd_bwd_timing_ms,
            ]
        )

    df = pd.DataFrame(
        data=data,
        columns=[
            "m",
            "k",
            "n",
            "recipe",
            "num_gemms",
            "grouped_fwd_bwd_time_ms",
        ],
    )
    print(df, "\n")
    return df


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--profile", action="store_true", help="Enable profiling mode")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="benchmark_output/",
        help="output path for report",
    )
    args = parser.parse_args()
    use_bias = False

    # Set the MKN values to benchmark
    mkns = []
    for m in [1024]:
        # for m in [4096, 8192, 16384]:
        # for n in [1024, 2048, 4096, 8192, 16384]:
        for n in [3072]:
            for k in [4096]:
                mkns.append((m, k, n))

    # recipe_list = [
    #     "bf16", "fp8_sub_channel",
    # ]
    recipe_list = [
        "fp8_sub_channel",
    ]
    # num_gemms_list = [16, 32]
    num_gemms_list = [4]

    if args.profile:
        # nsys profile --output=./benchmarks/linear/mkn_4096_4096_4096_numgemm_1_bf16 --trace=cuda,nvtx,cudnn,cublas python benchmarks/linear/benchmark_grouped_linear.py --profile
        # nsys profile --output=./benchmarks/linear/mkn_8192_8192_8192_numgemm_32_bf16 --trace=cuda,nvtx,cudnn,cublas python benchmarks/linear/benchmark_grouped_linear.py --profile
        # nsys profile --output=./benchmarks/linear/mkn_4096_4096_4096_numgemm_8_fp8_sub_channel --trace=cuda,nvtx,cudnn,cublas python benchmarks/linear/benchmark_grouped_linear.py --profile
        # nsys profile --output=./benchmarks/linear/mkn_8192_8192_8192_numgemm_2_fp8_sub_channel --trace=cuda,nvtx,cudnn,cublas python benchmarks/linear/benchmark_grouped_linear.py --profile
        mkns = [(4096, 4096, 4096)]
        recipe_list = ["fp8_sub_channel"]
        # recipe_list = ["bf16"]
        num_gemms_list = [8]
        torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__()

    # Initialize a dataframe to store the results
    df_linears = pd.DataFrame()

    # Run the fp8 benchmarks
    for num_gemms in num_gemms_list:
        print(f"========== Benchmarking with num_gemms={num_gemms} ==========")
        for recipe_name in recipe_list:
            df = run_benchmark_linear(
                mkns,
                recipe_name,
                use_bias,
                num_gemms=num_gemms,
            )
            df_linears = pd.concat([df_linears, df])
    print(df_linears)

    if args.profile:
        torch.autograd.profiler.emit_nvtx().__exit__(None, None, None)
@@ -197,6 +197,8 @@ class Float8BlockQuantizer : public Quantizer {
   std::pair<TensorWrapper, py::object> create_tensor(
       const std::vector<size_t>& shape, DType dtype,
       std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
+
+  std::vector<size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const;
 };

 class MXFP8Quantizer : public Quantizer {
......
@@ -108,10 +108,6 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
  * Transpose
  **************************************************************************************************/

-std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list,
-                                             std::optional<std::vector<py::object>> output_list,
-                                             std::vector<py::handle> quantizer_list, DType otype);
-
 at::Tensor fp8_transpose(at::Tensor input, DType otype,
                          std::optional<at::Tensor> output = std::nullopt);

@@ -182,10 +178,17 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
  **************************************************************************************************/

 py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output,
-                    std::optional<at::Tensor> noop);
+                    std::optional<at::Tensor> noop_flag);

 py::object dequantize(const py::handle &input, DType otype);

+std::vector<py::object> multi_tensor_quantize(const std::vector<at::Tensor> &tensor_list,
+                                              std::vector<py::handle> quantizer_list);
+
+std::vector<py::object> split_quantize(const at::Tensor &tensor,
+                                       const std::vector<int> &split_sections,
+                                       std::vector<py::handle> quantizer_list);
+
 /***************************************************************************************************
  * Bias gradient fusions
  **************************************************************************************************/
......
@@ -12,7 +12,9 @@
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>

-#include <stdexcept>
+#include <memory>
+#include <optional>
+#include <vector>

 #include "../common.h"
 #include "../extensions.h"
@@ -199,10 +201,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("weight"), py::arg("eps"), py::arg("ln_out"), py::arg("quantizer"),
         py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma"));
   m.def("rmsnorm_bwd", &transformer_engine::pytorch::rmsnorm_bwd, "Backward of RMSNorm");
-  m.def("fused_multi_quantize", &transformer_engine::pytorch::fused_multi_quantize,
-        "Fused Multi-tensor Cast + Transpose", py::arg("input_list"), py::arg("output_list"),
-        py::arg("quantizer_list"), py::arg("otype"));
+  m.def("multi_tensor_quantize", &transformer_engine::pytorch::multi_tensor_quantize,
+        "Multi-tensor quantize", py::arg("tensor_list"), py::arg("quantizer_list"));
+  m.def("split_quantize", &transformer_engine::pytorch::split_quantize,
+        "Split and multi-tensor quantize", py::arg("tensor"), py::arg("split_sections"),
+        py::arg("quantizer_list"));
   m.def("te_general_grouped_gemm", &transformer_engine::pytorch::te_general_grouped_gemm,
         "Grouped GEMM");
   m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O",
......
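The two bindings registered above cover the two allocation patterns. A rough usage sketch from Python (the quantizer objects are assumed to be built elsewhere by the recipe state; their construction is omitted here):

import torch
import transformer_engine_torch as tex

def demo(quantizers):
    # Sketch only: `quantizers` is a list of TE quantizer objects, one per output tensor.
    t0 = torch.randn(256, 128, device="cuda", dtype=torch.bfloat16)
    t1 = torch.randn(512, 128, device="cuda", dtype=torch.bfloat16)
    outs_a = tex.multi_tensor_quantize([t0, t1], quantizers)   # independent input tensors
    flat = torch.cat([t0, t1])                                  # one contiguous buffer
    outs_b = tex.split_quantize(flat, [256, 512], quantizers)   # split sizes + bulk-allocated outputs
    return outs_a, outs_b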
@@ -4,80 +4,16 @@
  * See LICENSE for license information.
  ************************************************************************/

-#include <pybind.h>
 #include <optional>
-#include <vector>

 #include "../extensions.h"
 #include "pybind.h"

-namespace transformer_engine::pytorch {
-
-std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list,
-                                             std::optional<std::vector<py::object>> output_list,
-                                             std::vector<py::handle> quantizer_list, DType otype) {
-  init_extension();
-  std::vector<NVTETensor> nvte_tensor_input_list;
-  std::vector<NVTETensor> nvte_tensor_output_list;
-  std::vector<py::object> py_output_objects_list;
-  std::vector<TensorWrapper> tensor_wrappers;
-  if (output_list.has_value()) {
-    py_output_objects_list = output_list.value();
-  }
-
-  // Choose implementation
-  // Note: Currently only have fused kernel for FP8 cast-transpose
-  bool with_fused_kernel = true;
-
-  // create TE tensors from input
-  for (size_t i = 0; i < input_list.size(); i++) {
-    auto input_tensor = makeTransformerEngineTensor(input_list[i]);
-    const NVTEShape input_shape = input_tensor.shape();
-    TensorWrapper output_tensor;
-    if (!detail::IsFloat8Quantizers(quantizer_list[i].ptr())) {
-      with_fused_kernel = false;
-    }
-    if (output_list == std::nullopt) {
-      std::unique_ptr<Quantizer> quantizer = convert_quantizer(quantizer_list[i]);
-      std::vector<size_t> output_shape(input_shape.data, input_shape.data + input_shape.ndim);
-      py::object o;
-      std::tie(output_tensor, o) = quantizer->create_tensor(output_shape, otype);
-      py_output_objects_list.push_back(o);
-    } else {
-      output_tensor = makeTransformerEngineTensor((*output_list)[i], quantizer_list[i]);
-    }
-    if (input_tensor.numel() == 0) continue;
-
-    nvte_tensor_output_list.emplace_back(output_tensor.data());
-    nvte_tensor_input_list.emplace_back(input_tensor.data());
-    tensor_wrappers.emplace_back(std::move(input_tensor));
-    tensor_wrappers.emplace_back(std::move(output_tensor));
-  }
-
-  // Check tensor lists
-  NVTE_CHECK(nvte_tensor_output_list.size() == nvte_tensor_input_list.size(),
-             "Number of input and output tensors must match");
-
-  for (size_t i = 0; i < nvte_tensor_output_list.size(); i++) {
-    if (nvte_tensor_columnwise_data(nvte_tensor_output_list[i]) == nullptr) {
-      with_fused_kernel = false;
-      break;
-    }
-  }
-
-  // Launch TE kernel
-  if (with_fused_kernel) {
-    NVTE_SCOPED_GIL_RELEASE({
-      nvte_multi_cast_transpose(nvte_tensor_input_list.size(), nvte_tensor_input_list.data(),
-                                nvte_tensor_output_list.data(), at::cuda::getCurrentCUDAStream());
-    });
-  } else {
-    for (size_t i = 0; i < py_output_objects_list.size(); i++) {
-      quantize(input_list[i], quantizer_list[i], py_output_objects_list[i], std::nullopt);
-    }
-  }
-  return py_output_objects_list;
-}
+namespace transformer_engine {
+namespace pytorch {

 at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional<at::Tensor> output) {
   init_extension();
@@ -108,4 +44,5 @@ at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional<at::Tensor
   return out;
 }

-}  // namespace transformer_engine::pytorch
+}  // namespace pytorch
+}  // namespace transformer_engine
...@@ -283,10 +283,8 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor( ...@@ -283,10 +283,8 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> rowwise_data) const { const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> rowwise_data) const {
using namespace pybind11::literals; using namespace pybind11::literals;
std::vector<int64_t> torch_shape; std::vector<int64_t> torch_shape;
size_t numel = 1;
for (auto s : shape) { for (auto s : shape) {
torch_shape.emplace_back(static_cast<int64_t>(s)); torch_shape.emplace_back(static_cast<int64_t>(s));
numel *= s;
} }
TensorWrapper tensor(this->get_scaling_mode()); TensorWrapper tensor(this->get_scaling_mode());
...@@ -296,10 +294,6 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor( ...@@ -296,10 +294,6 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); opts = opts.dtype(torch::kUInt8).device(torch::kCUDA);
scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA); scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA);
size_t k_dim = torch_shape.size() == 0 ? 1u : torch_shape.back();
size_t m_dim = numel / k_dim;
constexpr size_t kBlockLen = 128;
Float8BlockScaleTensorFormat data_format = Float8BlockScaleTensorFormat data_format =
(all_gather_usage ? Float8BlockScaleTensorFormat::COMPACT (all_gather_usage ? Float8BlockScaleTensorFormat::COMPACT
: Float8BlockScaleTensorFormat::GEMM_READY); : Float8BlockScaleTensorFormat::GEMM_READY);
...@@ -310,30 +304,9 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor( ...@@ -310,30 +304,9 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
} else { } else {
data_rowwise = at::empty(torch_shape, opts); data_rowwise = at::empty(torch_shape, opts);
} }
size_t sinv0 = 0; auto scale_shape = get_scale_shape(shape, false);
size_t sinv1 = 0; size_t sinv0 = scale_shape[0];
if (block_scaling_dim == 2) { size_t sinv1 = scale_shape[1];
// 2D scaling is always GEMM_READY for now
NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY,
"2D scaling is always GEMM_READY for now.");
sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
sinv1 = roundup((k_dim + kBlockLen - 1) / kBlockLen, 4);
} else if (block_scaling_dim == 1) {
// 1D scaling can be GEMM_READY or COMPACT
bool rowwise_compact = data_format == Float8BlockScaleTensorFormat::COMPACT;
// default rowwise scaling factor shape already transpose the scaling factor so it's GEMM_READY
sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
sinv1 = rowwise_compact ? m_dim : roundup(m_dim, 4);
// if the rowwise format is compact, the scaling factor is not be transposed
if (rowwise_compact) {
std::swap(sinv0, sinv1);
}
} else {
NVTE_ERROR(
"Unsupported block_scaling_dim in create_tensor rowwise. "
"Expected 1 or 2. Got ",
block_scaling_dim);
}
scale_inv_rowwise = scale_inv_rowwise =
at::empty({static_cast<int64_t>(sinv0), static_cast<int64_t>(sinv1)}, scale_opts); at::empty({static_cast<int64_t>(sinv0), static_cast<int64_t>(sinv1)}, scale_opts);
tensor.set_rowwise_data(data_rowwise.data_ptr(), this->dtype, shape); tensor.set_rowwise_data(data_rowwise.data_ptr(), this->dtype, shape);
...@@ -364,27 +337,9 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor( ...@@ -364,27 +337,9 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
columnwise_shape = shape; columnwise_shape = shape;
} }
} }
size_t sinv0 = 0; auto scale_shape = get_scale_shape(shape, true);
size_t sinv1 = 0; size_t sinv0 = scale_shape[0];
if (block_scaling_dim == 2) { size_t sinv1 = scale_shape[1];
// 2D scaling is always GEMM_READY for now
NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY,
"2D scaling is always GEMM_READY for now.");
sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
sinv1 = roundup((m_dim + kBlockLen - 1) / kBlockLen, 4);
} else if (block_scaling_dim == 1) {
bool columnwise_compact = data_format == Float8BlockScaleTensorFormat::COMPACT;
sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
sinv1 = columnwise_compact ? k_dim : roundup(k_dim, 4);
// GEMM READY case: scaling factor is [sinv0, sinv1], already transposed here for CuBLAS
// for COMPACT case, since we apply 128x1 scaling here without transposing columnwise data, scaling factor is also [sinv0, sinv1]
// so no need to swap sinv0 and sinv1 here
} else {
NVTE_ERROR(
"Unsupported block_scaling_dim in create_tensor columnwise. "
"Expected 1 or 2. Got ",
block_scaling_dim);
}
data_colwise = at::empty(torch_columnwise_shape, opts); data_colwise = at::empty(torch_columnwise_shape, opts);
scale_inv_colwise = scale_inv_colwise =
at::empty({static_cast<int64_t>(sinv0), static_cast<int64_t>(sinv1)}, scale_opts); at::empty({static_cast<int64_t>(sinv0), static_cast<int64_t>(sinv1)}, scale_opts);
...@@ -418,6 +373,81 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor( ...@@ -418,6 +373,81 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
return {std::move(tensor), std::move(ret)}; return {std::move(tensor), std::move(ret)};
} }
std::vector<size_t> Float8BlockQuantizer::get_scale_shape(const std::vector<size_t>& shape,
bool columnwise) const {
size_t numel = 1;
for (auto s : shape) {
numel *= s;
}
size_t k_dim = shape.size() == 0 ? 1u : shape.back();
size_t m_dim = numel / k_dim;
constexpr size_t kBlockLen = 128;
Float8BlockScaleTensorFormat data_format =
(all_gather_usage ? Float8BlockScaleTensorFormat::COMPACT
: Float8BlockScaleTensorFormat::GEMM_READY);
std::vector<size_t> scale_shape;
bool rowwise_usage = !columnwise;
if (rowwise_usage) {
// rowwise scaling factor shape
size_t sinv0 = 0;
size_t sinv1 = 0;
if (block_scaling_dim == 2) {
// 2D scaling is always GEMM_READY for now
NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY,
"2D scaling is always GEMM_READY for now.");
sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
sinv1 = roundup((k_dim + kBlockLen - 1) / kBlockLen, 4);
} else if (block_scaling_dim == 1) {
// 1D scaling can be GEMM_READY or COMPACT
bool rowwise_compact = data_format == Float8BlockScaleTensorFormat::COMPACT;
// the default rowwise scaling factor shape already transposes the scaling factor, so it's GEMM_READY
sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
sinv1 = rowwise_compact ? m_dim : roundup(m_dim, 4);
// if the rowwise format is compact, the scaling factor is not transposed
if (rowwise_compact) {
std::swap(sinv0, sinv1);
}
} else {
NVTE_CHECK(false,
"Unsupported block_scaling_dim in create_tensor rowwise."
"Expected 1 or 2. Got ",
block_scaling_dim);
}
scale_shape = {sinv0, sinv1};
} else {
// columnwise scaling factor shape
size_t sinv0 = 0;
size_t sinv1 = 0;
if (block_scaling_dim == 2) {
// 2D scaling is always GEMM_READY for now
NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY,
"2D scaling is always GEMM_READY for now.");
sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
sinv1 = roundup((m_dim + kBlockLen - 1) / kBlockLen, 4);
} else if (block_scaling_dim == 1) {
// 1D scaling can be GEMM_READY or COMPACT
bool columnwise_compact = data_format == Float8BlockScaleTensorFormat::COMPACT;
sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
sinv1 = columnwise_compact ? k_dim : roundup(k_dim, 4);
// GEMM READY case: scaling factor is [sinv0, sinv1], already transposed here for CuBLAS
// for COMPACT case, since we apply 128x1 scaling here without transposing columnwise data, scaling factor is also [sinv0, sinv1]
// so no need to swap sinv0 and sinv1 here
} else {
NVTE_CHECK(false,
"Unsupported block_scaling_dim in create_tensor columnwise."
"Expected 1 or 2. Got ",
block_scaling_dim);
}
scale_shape = {sinv0, sinv1};
}
return scale_shape;
}
MXFP8Quantizer::MXFP8Quantizer(const py::handle& quantizer) : Quantizer(quantizer) { MXFP8Quantizer::MXFP8Quantizer(const py::handle& quantizer) : Quantizer(quantizer) {
this->dtype = quantizer.attr("dtype").cast<DType>(); this->dtype = quantizer.attr("dtype").cast<DType>();
} }
......
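To make the branches in get_scale_shape concrete, here is a worked example for a (1024, 4096) input with 128-element blocks (my own arithmetic for illustration, not part of the commit):

# Assumptions: m_dim = 1024, k_dim = 4096, kBlockLen = 128,
# and roundup(x, 4) rounds x up to the next multiple of 4.
#   1D scaling, GEMM_READY, rowwise:    [ceil(4096/128), roundup(1024, 4)] = [32, 1024]
#   1D scaling, COMPACT, rowwise:       the same pair swapped              = [1024, 32]
#   1D scaling, GEMM_READY, columnwise: [ceil(1024/128), roundup(4096, 4)] = [8, 4096]
#   2D scaling, GEMM_READY, rowwise:    [ceil(1024/128), roundup(32, 4)]   = [8, 32]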
...@@ -24,7 +24,6 @@ from ..fp8 import FP8GlobalStateManager ...@@ -24,7 +24,6 @@ from ..fp8 import FP8GlobalStateManager
from ..utils import ( from ..utils import (
divide, divide,
cast_if_needed, cast_if_needed,
assert_dim_for_fp8_exec,
clear_tensor_data, clear_tensor_data,
init_method_constant, init_method_constant,
requires_grad, requires_grad,
...@@ -38,7 +37,7 @@ from ..distributed import ( ...@@ -38,7 +37,7 @@ from ..distributed import (
from ..cpp_extensions import ( from ..cpp_extensions import (
general_grouped_gemm, general_grouped_gemm,
) )
from ..constants import GemmParallelModes, dist_group_type, TE_DType from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing from ..graph import is_graph_capturing
from ..cpu_offload import is_cpu_offload_enabled from ..cpu_offload import is_cpu_offload_enabled
...@@ -87,20 +86,9 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -87,20 +86,9 @@ class _GroupedLinear(torch.autograd.Function):
weights = weights_and_biases[:num_gemms] weights = weights_and_biases[:num_gemms]
biases = weights_and_biases[num_gemms:] biases = weights_and_biases[num_gemms:]
device = inp.device device = inp.device
# Make sure input dimensions are compatible
in_features = weights[0].shape[-1]
assert inp.shape[-1] == in_features, "GEMM not possible"
inputmats = torch.split(inp.view(-1, in_features), m_splits)
if fp8:
assert_dim_for_fp8_exec(*inputmats, *weights)
# Cast input to expected dtype
inputmats_no_fp8 = [cast_if_needed(mat, activation_dtype) for mat in inputmats]
inputmats = []
weight_requires_grad = weights[0].requires_grad weight_requires_grad = weights[0].requires_grad
# Configure quantizers
if input_quantizers[0] is not None: if input_quantizers[0] is not None:
for input_quantizer in input_quantizers: for input_quantizer in input_quantizers:
input_quantizer.set_usage( input_quantizer.set_usage(
...@@ -120,17 +108,25 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -120,17 +108,25 @@ class _GroupedLinear(torch.autograd.Function):
for output_quantizer in output_quantizers: for output_quantizer in output_quantizers:
output_quantizer.set_usage(rowwise=True, columnwise=False) output_quantizer.set_usage(rowwise=True, columnwise=False)
fprop_gemm_use_split_accumulator = _2X_ACC_FPROP # Initialize input tensors
if fp8: in_features = weights[0].size(-1)
recipe = FP8GlobalStateManager.get_fp8_recipe() if inp.size(-1) != in_features:
if hasattr(recipe, "fp8_gemm_fprop"): raise ValueError(
fprop_gemm_use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator f"Input tensor (shape={tuple(inp.size())}) is not compatible with "
inputmats = tex.fused_multi_quantize( f"weight tensor (shape={tuple(weights[0].size())})"
inputmats_no_fp8, None, input_quantizers, TE_DType[activation_dtype]
) )
weights_fp8 = [] inp_view = inp.reshape(-1, in_features)
bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype inputmats: list
if fp8:
inputmats = tex.split_quantize(inp_view, m_splits, input_quantizers)
else:
inputmats = torch.split(cast_if_needed(inp_view, activation_dtype), m_splits)
# Initialize weights
weights_fp8: list
if fp8:
# FP8 cast to workspace buffer # FP8 cast to workspace buffer
weights_fp8 = []
update_workspace = is_first_microbatch is None or is_first_microbatch update_workspace = is_first_microbatch is None or is_first_microbatch
for i in range(num_gemms): for i in range(num_gemms):
weight_fp8 = module.get_weight_workspace( weight_fp8 = module.get_weight_workspace(
...@@ -143,18 +139,29 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -143,18 +139,29 @@ class _GroupedLinear(torch.autograd.Function):
weights_fp8.append(weight_fp8) weights_fp8.append(weight_fp8)
else: else:
inputmats = inputmats_no_fp8
bias_dtype = activation_dtype
weights_fp8 = [cast_if_needed(weight, activation_dtype) for weight in weights] weights_fp8 = [cast_if_needed(weight, activation_dtype) for weight in weights]
# Initialize biases
bias_dtype = activation_dtype
if fp8 and activation_dtype == torch.float32:
bias_dtype = torch.bfloat16 # FP8 GEMM only supports BF16/FP16 bias
biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases
# Initialize output tensor
out = torch.empty( out = torch.empty(
[sum(m_splits), weights_fp8[0].size(0)], [sum(m_splits), weights_fp8[0].size(0)],
dtype=activation_dtype, dtype=activation_dtype,
device=device, device=device,
) )
# Choose whether to use split accumulator
use_split_accumulator = _2X_ACC_FPROP
if fp8:
recipe = FP8GlobalStateManager.get_fp8_recipe()
if hasattr(recipe, "fp8_gemm_fprop"):
use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator
# Perform GEMM
_ = general_grouped_gemm( _ = general_grouped_gemm(
weights_fp8, weights_fp8,
inputmats, inputmats,
...@@ -165,7 +172,7 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -165,7 +172,7 @@ class _GroupedLinear(torch.autograd.Function):
m_splits=m_splits, m_splits=m_splits,
bias=biases, bias=biases,
use_bias=use_bias, use_bias=use_bias,
use_split_accumulator=fprop_gemm_use_split_accumulator, use_split_accumulator=use_split_accumulator,
) )
if fp8_calibration: if fp8_calibration:
...@@ -247,36 +254,44 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -247,36 +254,44 @@ class _GroupedLinear(torch.autograd.Function):
w.main_grad = main_grads[i] w.main_grad = main_grads[i]
weights[i] = w weights[i] = w
# preprocess grad_output # Preprocess grad output
grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1])
grad_output = grad_output.contiguous()
grad_output_mats = torch.split(
grad_output.view(-1, grad_output.shape[-1]), ctx.m_splits
)
grad_output = [None] * ctx.num_gemms grad_output = [None] * ctx.num_gemms
grad_biases = [None] * ctx.num_gemms grad_biases = [None] * ctx.num_gemms
if ctx.fp8: if ctx.fp8:
if ctx.use_bias: if ctx.use_bias:
# unfuse bgrad for now until cast_transpose + dgrad calculation is ready grad_output_mats = torch.split(grad_output_view, ctx.m_splits)
# for Float8BlockQuantizer. recipe = ctx.fp8_recipe
if ctx.fp8_recipe.float8_block_scaling(): if recipe.delayed() or recipe.float8_current_scaling() or recipe.mxfp8():
for i in range(ctx.num_gemms): # Fused bias grad + quantize kernel
grad_biases[i] = grad_output_mats[i].sum(dim=0)
grad_output[i] = ctx.grad_output_quantizers[i](grad_output_mats[i])
else:
for i in range(ctx.num_gemms): for i in range(ctx.num_gemms):
grad_biases[i], grad_output[i] = tex.bgrad_quantize( grad_biases[i], grad_output[i] = tex.bgrad_quantize(
grad_output_mats[i], ctx.grad_output_quantizers[i] grad_output_mats[i],
ctx.grad_output_quantizers[i],
) )
else:
# Unfused bias grad and multi-tensor quantize
for i in range(ctx.num_gemms):
grad_biases[i] = grad_output_mats[i].sum(dim=0)
grad_output = tex.split_quantize(
grad_output_view,
ctx.m_splits,
ctx.grad_output_quantizers,
)
else: else:
grad_output = tex.fused_multi_quantize( # Multi-tensor quantize
grad_output_mats, grad_output = tex.split_quantize(
None, grad_output_view,
ctx.m_splits,
ctx.grad_output_quantizers, ctx.grad_output_quantizers,
TE_DType[ctx.activation_dtype],
) )
else: else:
grad_output = grad_output_mats # Only split grad output. Grad bias is fused with
# wgrad GEMM.
grad_output = torch.split(
cast_if_needed(grad_output_view, ctx.activation_dtype),
ctx.m_splits,
)
if ctx.is_first_microbatch is not None: if ctx.is_first_microbatch is not None:
accumulate_wgrad_into_param_main_grad = ( accumulate_wgrad_into_param_main_grad = (
......
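The backward path above chooses between the fused bias-grad kernel and the new split_quantize call depending on the recipe. A self-contained sketch of that decision (my summary; only tex.bgrad_quantize and tex.split_quantize are taken from the diff, the helper name and arguments are hypothetical):

import torch
import transformer_engine_torch as tex

def quantize_grad_output(grad_output_view, m_splits, quantizers, use_bias, has_fused_bgrad):
    # Sketch of the FP8 branch only; `has_fused_bgrad` stands in for the recipe check
    # (delayed, current-scaling, or MXFP8 recipes support the fused bias-grad kernel).
    grad_biases = [None] * len(m_splits)
    if use_bias and has_fused_bgrad:
        grads = []
        for i, chunk in enumerate(torch.split(grad_output_view, m_splits)):
            grad_biases[i], g = tex.bgrad_quantize(chunk, quantizers[i])  # fused bgrad + quantize
            grads.append(g)
        return grads, grad_biases
    if use_bias:
        # e.g. FP8 block scaling: reduce the bias grad separately, then fuse split + quantize
        grad_biases = [chunk.sum(dim=0) for chunk in torch.split(grad_output_view, m_splits)]
    return tex.split_quantize(grad_output_view, m_splits, quantizers), grad_biases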
@@ -42,7 +42,6 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
     def __new__(
         cls,
-        *args,
         rowwise_data: Optional[torch.Tensor],
         rowwise_scale_inv: Optional[torch.Tensor],
         columnwise_data: Optional[torch.Tensor],
@@ -50,7 +49,8 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
         fp8_dtype: TE_DType,
         quantizer: Quantizer,
         is_2D_scaled: bool,
-        data_format: Float8BlockScaleTensorFormat = Float8BlockScaleTensorFormat.GEMM_READY,
+        data_format: Float8BlockScaleTensorFormat,
+        *args,
         **kwargs,
     ):
         instance = super().__new__(cls, *args, **kwargs)
......
@@ -10,6 +10,7 @@ import math
 import torch
 import transformer_engine_torch as tex
 from transformer_engine_torch import DType as TE_DType
+from transformer_engine_torch import Float8BlockScaleTensorFormat
 from transformer_engine.common.recipe import Float8BlockScaling, Recipe

 from ._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
@@ -294,6 +295,37 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
     holds configuration about quantization and dequantization modes.
     """

+    # NOTE: We reorder the *args so that we can instantiate a Float8BlockwiseQTensorBase with positional args,
+    # which significantly reduces the Pybind11 overhead when calling the constructor from C++.
+    def __new__(
+        cls,
+        *args,
+        rowwise_data: Optional[torch.Tensor],
+        rowwise_scale_inv: Optional[torch.Tensor],
+        columnwise_data: Optional[torch.Tensor],
+        columnwise_scale_inv: Optional[torch.Tensor],
+        fp8_dtype: TE_DType,
+        quantizer: Quantizer,
+        is_2D_scaled: bool,
+        data_format: tex.Float8BlockScaleTensorFormat = Float8BlockScaleTensorFormat.GEMM_READY,
+        **kwargs,
+    ):
+        instance = super().__new__(
+            cls,
+            rowwise_data,
+            rowwise_scale_inv,
+            columnwise_data,
+            columnwise_scale_inv,
+            fp8_dtype,
+            quantizer,
+            is_2D_scaled,
+            data_format,
+            *args,
+            **kwargs,
+        )
+        return instance
+
     def __repr__(self, *, tensor_contents=None):
         return (
             f"Float8BlockwiseQTensor(fp8_dtype={self._fp8_dtype},"
......
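The reordered __new__ above exists so the C++ side can drive the base-class constructor with positional arguments; keyword calls through pybind11 additionally build a Python dict on every invocation. A quick, generic Python-level illustration of the positional-vs-keyword gap (numbers are machine-dependent, and this is only an analogy for the pybind11 case):

import timeit

def ctor(a, b, c, d, e, f, g, h):
    # Stand-in for a constructor with many parameters.
    return (a, b, c, d, e, f, g, h)

args = (1, 2, 3, 4, 5, 6, 7, 8)
kwargs = {k: v for k, v in zip("abcdefgh", args)}

print("positional:", timeit.timeit(lambda: ctor(*args), number=200_000))
print("keyword:   ", timeit.timeit(lambda: ctor(**kwargs), number=200_000))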