"vscode:/vscode.git/clone" did not exist on "923797fea4d80a4dac4409ece3c450b84d5ba001"
Unverified Commit c30e961f authored by Zhongbo Zhu, committed by GitHub

[PyTorch][MoE] Reduce CPU Overhead by Fusing Torch Empty Calls (#1793)



* finish python ref impl for bulk alloc
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* c++ bulk alloc worked, still draft version
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* clean up
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* resolve rebase conflict
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add license
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* use shared_ptr to auto manage reference count
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* attempt to fix misc training error
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* attempt to handle case where experts get zero token
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* updated with fused C++ function calls
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* clean up
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* experiment with reducing py object construction time
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* fix seg fault bug in inference mode
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* fix lint
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* fuse torch split into bulk alloc
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* clean up
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* rebase to latest main
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* fix unit test failure
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* fix lint error
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* refactor create_tensor to use get_scale_shape
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* refactor quantize to call quantize_cpp
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* Implement separate functions for multi-tensor quantize and split + multi-tensor quantize
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update grouped linear module with fused split+quantize func
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Move multi-tensor quantize func to cast.cpp
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Do not expose quantizer helper function externally
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter warnings
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Revert cuDNN frontend commit
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* fix corner cases with zero tokens
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* add comments
Signed-off-by: zhongboz <zhongboz@nvidia.com>

---------
Signed-off-by: zhongboz <zhongboz@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
parent 7db72dbc
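At the Python level, the gist of this change is that GroupedLinear no longer splits its input with torch.split and quantizes each expert's chunk separately; it now makes a single split_quantize extension call that bulk-allocates every output tensor and quantizes all splits in one C++ call. The sketch below is editorial and illustrative only, not part of this commit; it assumes `quantizers` is a list of already-configured Transformer Engine quantizer objects (as GroupedLinear holds) and that the tensors live on a CUDA device.

import torch
import transformer_engine_torch as tex

def quantize_inputs_old(inp, m_splits, quantizers):
    # old path: a Python-side split plus one allocation/quantize call per expert
    chunks = torch.split(inp.view(-1, inp.shape[-1]), m_splits)
    return [quantizer(chunk) for quantizer, chunk in zip(quantizers, chunks)]

def quantize_inputs_new(inp, m_splits, quantizers):
    # new path: one C++ call fuses the split, the output allocations (a single
    # bulk buffer), and the quantization, avoiding per-expert torch.empty calls
    return tex.split_quantize(inp.view(-1, inp.shape[-1]), m_splits, quantizers)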
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import argparse
import torch
import torch.utils.benchmark as benchmark
import pandas as pd
import pathlib
from transformer_engine.pytorch.module import GroupedLinear
from transformer_engine.common.recipe import Float8BlockScaling
from transformer_engine.pytorch.fp8 import fp8_autocast
from contextlib import nullcontext
RECIPES = {
"bf16": None,
"fp8_sub_channel": Float8BlockScaling(),
}
def run_linear_multiple_steps(layer, x, m_splits, mode, gradient, run_num_steps=1, recipe=None):
assert mode in ["fwd_only", "fwd_bwd"]
fp8_context = (
fp8_autocast(enabled=True, fp8_recipe=recipe) if recipe is not None else nullcontext()
)
# print(f"fp8_context: {fp8_context} and is it nullcontext? {isinstance(fp8_context, nullcontext)}")
if mode == "fwd_only":
with torch.no_grad(), fp8_context:
for i in range(run_num_steps):
y_q = layer.forward(
x,
m_splits,
is_first_microbatch=(i == 0),
)
return y_q
else:
# reset gradients
layer.zero_grad()
x.grad = None
with fp8_context:
for i in range(run_num_steps):
label = f"step_{i}"
torch.cuda.nvtx.range_push(label)
y_q = layer.forward(
x,
m_splits,
is_first_microbatch=(i == 0),
)
y_q.backward(gradient)
torch.cuda.nvtx.range_pop()
grads_q = []
grads_q.append(x.grad)
        # remaining gradients are with respect to the model parameters
for p in layer.parameters():
if p.requires_grad:
grads_q.append(p.grad)
return y_q, grads_q
def benchmark_linear(
x,
ws,
m_splits,
bias,
recipe_name,
mode,
num_gemms=4,
):
params_dtype = torch.bfloat16
recipe = RECIPES[recipe_name]
in_features = x.shape[1]
out_features = ws[0].shape[0]
gradient = torch.ones((x.shape[0], out_features), dtype=torch.bfloat16, device=x.device)
layer = GroupedLinear(
num_gemms,
in_features,
out_features,
bias=bias is not None,
params_dtype=params_dtype,
)
layer = layer.to("cuda")
with torch.no_grad():
for i in range(num_gemms):
weight_i = getattr(layer, f"weight{i}")
weight_i.copy_(ws[i])
if bias is not None:
bias_i = getattr(layer, f"bias{i}")
bias_i.copy_(bias)
num_microbatches = 32
label = f"{recipe_name}_{'grouped'}"
torch.cuda.nvtx.range_push(label)
timing = benchmark.Timer(
stmt=(
"run_linear_multiple_steps(layer, x, m_splits, mode, gradient, num_microbatches,"
" recipe)"
),
globals={
"run_linear_multiple_steps": run_linear_multiple_steps,
"layer": layer,
"x": x,
"m_splits": m_splits,
"mode": mode,
"gradient": gradient,
"num_microbatches": num_microbatches,
"recipe": recipe,
},
num_threads=1,
).blocked_autorange(min_run_time=5)
print(f"{recipe_name}: {timing} \n")
timing_ms = timing.median * 1000 / num_microbatches
return timing_ms
def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4):
data = []
assert not use_bias, "Bias is not supported for GroupedLinear benchmark"
print(f"========== Benchmarking {recipe_name} ==========")
for m, k, n in mkns:
device = "cuda"
x = torch.randn((m, k), dtype=torch.bfloat16, device=device, requires_grad=True)
ws = [torch.randn((n, k), dtype=torch.bfloat16, device=device) for _ in range(num_gemms)]
assert m % num_gemms == 0
m_splits = [m // num_gemms] * num_gemms
# Bias is not supported for GroupedLinear benchmark
bias = None
# Run the benchmark
print(f"fwd_m={m}, fwd_k={k}, fwd_n={n}")
grouped_fwd_bwd_timing_ms = benchmark_linear(
x,
ws,
m_splits,
bias,
recipe_name,
mode="fwd_bwd",
num_gemms=num_gemms,
)
# Append the results
data.append(
[
m,
k,
n,
recipe_name,
num_gemms,
grouped_fwd_bwd_timing_ms,
]
)
df = pd.DataFrame(
data=data,
columns=[
"m",
"k",
"n",
"recipe",
"num_gemms",
"grouped_fwd_bwd_time_ms",
],
)
print(df, "\n")
return df
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--profile", action="store_true", help="Enable profiling mode")
parser.add_argument(
"--output_dir",
type=str,
default="benchmark_output/",
help="output path for report",
)
args = parser.parse_args()
use_bias = False
# Set the MKN values to benchmark
mkns = []
for m in [1024]:
# for m in [4096, 8192, 16384]:
# for n in [1024, 2048, 4096, 8192, 16384]:
for n in [3072]:
for k in [4096]:
mkns.append((m, k, n))
# recipe_list = [
# "bf16", "fp8_sub_channel",
# ]
recipe_list = [
"fp8_sub_channel",
]
# num_gemms_list = [16, 32]
num_gemms_list = [4]
if args.profile:
# nsys profile --output=./benchmarks/linear/mkn_4096_4096_4096_numgemm_1_bf16 --trace=cuda,nvtx,cudnn,cublas python benchmarks/linear/benchmark_grouped_linear.py --profile
# nsys profile --output=./benchmarks/linear/mkn_8192_8192_8192_numgemm_32_bf16 --trace=cuda,nvtx,cudnn,cublas python benchmarks/linear/benchmark_grouped_linear.py --profile
# nsys profile --output=./benchmarks/linear/mkn_4096_4096_4096_numgemm_8_fp8_sub_channel --trace=cuda,nvtx,cudnn,cublas python benchmarks/linear/benchmark_grouped_linear.py --profile
# nsys profile --output=./benchmarks/linear/mkn_8192_8192_8192_numgemm_2_fp8_sub_channel --trace=cuda,nvtx,cudnn,cublas python benchmarks/linear/benchmark_grouped_linear.py --profile
mkns = [(4096, 4096, 4096)]
recipe_list = ["fp8_sub_channel"]
# recipe_list = ["bf16"]
num_gemms_list = [8]
        # keep a reference to the NVTX context manager so it can be exited cleanly later
        nvtx_context = torch.autograd.profiler.emit_nvtx(record_shapes=True)
        nvtx_context.__enter__()
# Initialize a dataframe to store the results
df_linears = pd.DataFrame()
# Run the fp8 benchmarks
for num_gemms in num_gemms_list:
print(f"========== Benchmarking with num_gemms={num_gemms} ==========")
for recipe_name in recipe_list:
df = run_benchmark_linear(
mkns,
recipe_name,
use_bias,
num_gemms=num_gemms,
)
df_linears = pd.concat([df_linears, df])
print(df_linears)
    if args.profile:
        nvtx_context.__exit__(None, None, None)
......@@ -197,6 +197,8 @@ class Float8BlockQuantizer : public Quantizer {
std::pair<TensorWrapper, py::object> create_tensor(
const std::vector<size_t>& shape, DType dtype,
std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
std::vector<size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const;
};
class MXFP8Quantizer : public Quantizer {
......
......@@ -108,10 +108,6 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
* Transpose
**************************************************************************************************/
std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list,
std::optional<std::vector<py::object>> output_list,
std::vector<py::handle> quantizer_list, DType otype);
at::Tensor fp8_transpose(at::Tensor input, DType otype,
std::optional<at::Tensor> output = std::nullopt);
......@@ -182,10 +178,17 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
**************************************************************************************************/
py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output,
std::optional<at::Tensor> noop);
std::optional<at::Tensor> noop_flag);
py::object dequantize(const py::handle &input, DType otype);
std::vector<py::object> multi_tensor_quantize(const std::vector<at::Tensor> &tensor_list,
std::vector<py::handle> quantizer_list);
std::vector<py::object> split_quantize(const at::Tensor &tensor,
const std::vector<int> &split_sections,
std::vector<py::handle> quantizer_list);
/***************************************************************************************************
* Bias gradient fusions
**************************************************************************************************/
......
......@@ -6,60 +6,51 @@
#include "transformer_engine/cast.h"
#include <cstdint>
#include <memory>
#include <optional>
#include <tuple>
#include <utility>
#include <vector>
#include "../extensions.h"
#include "common.h"
#include "pybind.h"
#include "transformer_engine/transformer_engine.h"
namespace transformer_engine::pytorch {
namespace transformer_engine {
namespace pytorch {
py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::object& output,
std::optional<at::Tensor> noop) {
init_extension();
auto my_quantizer = convert_quantizer(quantizer);
auto input_tensor = tensor.contiguous();
namespace {
const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor);
const auto& te_input_shape = te_input.shape();
std::vector<size_t> input_shape(te_input_shape.data, te_input_shape.data + te_input_shape.ndim);
auto fake_tensor_type = tensor.scalar_type();
if (!detail::IsFloatingPointType(fake_tensor_type)) {
fake_tensor_type = at::kFloat;
}
TensorWrapper te_output;
py::object out;
if (output.is_none()) {
DType fake_te_type = GetTransformerEngineDType(fake_tensor_type);
std::tie(te_output, out) = my_quantizer->create_tensor(input_shape, fake_te_type);
} else {
out = output;
te_output = makeTransformerEngineTensor(output, quantizer);
}
std::vector<size_t> get_tensor_shape(const TensorWrapper &tensor) {
const auto &shape = tensor.shape();
return std::vector<size_t>(shape.data, shape.data + shape.ndim);
}
TensorWrapper te_noop;
if (noop.has_value()) {
te_noop = makeTransformerEngineTensor(*noop);
} else {
te_noop = TensorWrapper();
void quantize_impl(const TensorWrapper &input, py::handle &quantizer_py,
std::unique_ptr<Quantizer> &quantizer_cpp, TensorWrapper &output,
TensorWrapper &noop_flag) {
// Check tensor dims
NVTE_CHECK(get_tensor_shape(input) == get_tensor_shape(output),
"Input tensor (shape=", get_tensor_shape(input),
") and output tensor (shape=", get_tensor_shape(output), ") do not match");
if (input.numel() == 0) {
return;
}
if (te_output.numel() == 0) return out;
// Recipe-specific configuration
QuantizationConfigWrapper quant_config;
quant_config.set_noop_tensor(te_noop.data());
if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
// my_quantizer here has to be a Float8CurrentScalingQuantizer
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer*>(my_quantizer.get());
NVTE_SCOPED_GIL_RELEASE({
nvte_compute_amax(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
});
quant_config.set_noop_tensor(noop_flag.data());
if (detail::IsFloat8CurrentScalingQuantizers(quantizer_py.ptr())) {
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer *>(quantizer_cpp.get());
NVTE_SCOPED_GIL_RELEASE(
{ nvte_compute_amax(input.data(), output.data(), at::cuda::getCurrentCUDAStream()); });
    // check if we need to do amax reduction (depending on model parallel configs)
if (my_quantizer_cs->with_amax_reduction) {
c10::intrusive_ptr<dist_group_type> process_group_ptr = my_quantizer_cs->amax_reduction_group;
      // construct torch tensor from NVTEBasicTensor without reallocating memory
at::Tensor& amax_tensor_torch = my_quantizer_cs->amax;
at::Tensor &amax_tensor_torch = my_quantizer_cs->amax;
std::vector<at::Tensor> tensors = {amax_tensor_torch};
// allreduce amax tensor
c10d::AllreduceOptions allreduce_opts;
......@@ -72,37 +63,70 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob
quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
NVTE_SCOPED_GIL_RELEASE({
nvte_compute_scale_from_amax(te_output.data(), quant_config,
at::cuda::getCurrentCUDAStream());
nvte_compute_scale_from_amax(output.data(), quant_config, at::cuda::getCurrentCUDAStream());
});
// set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape);
} else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
auto my_quantizer_bw = static_cast<Float8BlockQuantizer*>(my_quantizer.get());
// set amax ptr to null in output TensorWrapper to avoid atomic amax updates in kernel
output.set_amax(nullptr, DType::kFloat32, output.defaultShape);
} else if (detail::IsFloat8BlockwiseQuantizers(quantizer_py.ptr())) {
auto my_quantizer_bw = static_cast<Float8BlockQuantizer *>(quantizer_cpp.get());
quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
if (my_quantizer_bw->all_gather_usage) {
quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT);
}
}
// Perform quantization
NVTE_SCOPED_GIL_RELEASE({
nvte_quantize_v2(te_input.data(), te_output.data(), quant_config,
at::cuda::getCurrentCUDAStream());
nvte_quantize_v2(input.data(), output.data(), quant_config, at::cuda::getCurrentCUDAStream());
});
}
return out;
} // namespace
py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output,
std::optional<at::Tensor> noop_flag) {
// Convert quantizer to C++ object
auto quantizer_cpp = convert_quantizer(quantizer);
// Convert input tensor to C++ object
auto input_contiguous = tensor.contiguous();
const auto input_cpp = makeTransformerEngineTensor(input_contiguous);
// Initialize output tensor
TensorWrapper output_cpp;
py::object output_py;
if (output.is_none()) {
const auto shape = get_tensor_shape(input_cpp);
const auto fake_dtype = input_cpp.dtype();
std::tie(output_cpp, output_py) = quantizer_cpp->create_tensor(shape, fake_dtype);
} else {
output_py = output;
output_cpp = makeTransformerEngineTensor(output_py, quantizer);
}
// Initialize no-op flag
TensorWrapper noop_flag_cpp;
if (noop_flag.has_value()) {
noop_flag_cpp = makeTransformerEngineTensor(*noop_flag);
}
// Perform quantization
quantize_impl(input_cpp, quantizer, quantizer_cpp, output_cpp, noop_flag_cpp);
return output_py;
}
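As an aside on the recipe-specific branch handled above: for FP8 current scaling, the extension first computes the input's amax (optionally all-reducing it across the amax reduction group), derives a scale from it, and only then casts. A much-simplified, illustrative Python version of that arithmetic is sketched below; the FP8 maximum, the epsilon handling, and the function name are assumptions for the example, not the library's implementation.

import torch

def current_scaling_quantize_sketch(x, fp8_max=448.0, amax_epsilon=0.0,
                                    force_pow_2_scales=False):
    # amax -> scale (optionally rounded down to a power of two) -> FP8 cast
    amax = x.abs().amax().float()
    amax = torch.clamp(amax, min=max(amax_epsilon, 1e-12))  # avoid dividing by zero
    scale = fp8_max / amax
    if force_pow_2_scales:
        scale = torch.exp2(torch.floor(torch.log2(scale)))
    data = (x.float() * scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return data, scale.reciprocal()  # FP8 data plus scale_inv for dequantization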
py::object dequantize(const py::handle& input, transformer_engine::DType otype) {
py::object dequantize(const py::handle &input, transformer_engine::DType otype) {
init_extension();
const auto none = py::none();
const auto& input_tensor = makeTransformerEngineTensor(input, none);
const auto &input_tensor = makeTransformerEngineTensor(input, none);
NoneQuantizer q(none);
const auto& shape = convertShape(input_tensor.shape());
const auto &shape = convertShape(input_tensor.shape());
auto [out_tensor, out] = q.create_tensor(shape, otype);
......@@ -113,9 +137,348 @@ py::object dequantize(const py::handle& input, transformer_engine::DType otype)
return out;
}
namespace {
void multi_tensor_quantize_impl(const std::vector<TensorWrapper> &input_list,
std::vector<py::handle> &quantizer_py_list,
std::vector<std::unique_ptr<Quantizer>> &quantizer_cpp_list,
std::vector<TensorWrapper> &output_list) {
// Check number of tensors
const size_t num_tensors = input_list.size();
NVTE_CHECK(quantizer_py_list.size() == num_tensors, "Expected ", num_tensors,
" Python quantizers, but got ", quantizer_py_list.size());
NVTE_CHECK(quantizer_cpp_list.size() == num_tensors, "Expected ", num_tensors,
" C++ quantizers, but got ", quantizer_cpp_list.size());
NVTE_CHECK(output_list.size() == num_tensors, "Expected ", num_tensors,
" output tensors, but got ", output_list.size());
// Choose implementation
// Note: Currently only have fused kernel for FP8 delayed scaling
bool with_fused_kernel = true;
for (size_t i = 0; i < num_tensors; i++) {
if (!detail::IsFloat8Quantizers(quantizer_py_list[i].ptr())) {
with_fused_kernel = false;
break;
}
if (nvte_tensor_columnwise_data(output_list[i].data()) == nullptr) {
with_fused_kernel = false;
break;
}
}
// Launch TE kernel
if (with_fused_kernel) {
// Fused kernel for multi-tensor quantize
std::vector<NVTETensor> nvte_tensor_input_list;
std::vector<NVTETensor> nvte_tensor_output_list;
for (size_t i = 0; i < num_tensors; ++i) {
nvte_tensor_input_list.push_back(input_list[i].data());
nvte_tensor_output_list.push_back(output_list[i].data());
}
NVTE_SCOPED_GIL_RELEASE({
nvte_multi_cast_transpose(nvte_tensor_input_list.size(), nvte_tensor_input_list.data(),
nvte_tensor_output_list.data(), at::cuda::getCurrentCUDAStream());
});
} else {
// Quantize kernels individually
TensorWrapper dummy_noop_flag;
for (size_t i = 0; i < num_tensors; ++i) {
quantize_impl(input_list[i], quantizer_py_list[i], quantizer_cpp_list[i], output_list[i],
dummy_noop_flag);
}
}
}
} // namespace
std::vector<py::object> multi_tensor_quantize(const std::vector<at::Tensor> &tensor_list,
std::vector<py::handle> quantizer_list) {
// Check number of tensors
const size_t num_tensors = tensor_list.size();
NVTE_CHECK(quantizer_list.size() == num_tensors, "Expected ", num_tensors,
" quantizers, but got ", quantizer_list.size());
// Convert quantizers to C++ objects
std::vector<std::unique_ptr<Quantizer>> quantizer_cpp_list;
for (size_t i = 0; i < num_tensors; i++) {
quantizer_cpp_list.push_back(convert_quantizer(quantizer_list[i]));
}
// Initialize input and output tensors
std::vector<TensorWrapper> input_cpp_list;
std::vector<TensorWrapper> output_cpp_list;
std::vector<py::object> output_py_list;
for (size_t i = 0; i < num_tensors; ++i) {
// Convert input tensor to C++ object
const auto &input_py = tensor_list[i];
NVTE_CHECK(input_py.is_contiguous(), "Input tensor ", i, " is not contiguous");
input_cpp_list.emplace_back(makeTransformerEngineTensor(input_py));
const auto &input_cpp = input_cpp_list.back();
const auto input_shape = input_cpp.shape();
const auto input_dtype = GetTransformerEngineDType(input_py.scalar_type());
// Construct output tensor
std::vector<size_t> output_shape(input_shape.data, input_shape.data + input_shape.ndim);
auto [output_cpp, output_py] = quantizer_cpp_list[i]->create_tensor(output_shape, input_dtype);
output_cpp_list.emplace_back(std::move(output_cpp));
output_py_list.emplace_back(std::move(output_py));
}
// Perform multi-tensor quantization
multi_tensor_quantize_impl(input_cpp_list, quantizer_list, quantizer_cpp_list, output_cpp_list);
return output_py_list;
}
namespace {
std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_fp8_blockwise_tensors(
std::vector<std::vector<size_t>> &shape_list, std::vector<py::handle> &quantizer_py_list,
std::vector<Float8BlockQuantizer *> &quantizer_cpp_list) {
init_extension();
std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> retval;
auto &tensor_py_list = std::get<0>(retval);
auto &tensor_cpp_list = std::get<1>(retval);
// Number of tensors
const size_t num_tensors = shape_list.size();
if (num_tensors == 0) {
return retval;
}
// Quantization parameters
const auto rowwise_usage = quantizer_cpp_list[0]->rowwise_usage;
const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage;
const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode();
const auto is_2D_scaled = scaling_mode == NVTE_BLOCK_SCALING_2D;
const auto fp8_dtype = quantizer_cpp_list[0]->dtype;
constexpr size_t fp8_elem_size = 1;
constexpr size_t scale_elem_size = 4;
// Helper function to construct tensor view
// Note: Deleter holds a shared_ptr for the buffer, so the buffer
// will survive until all views are deleted.
auto make_torch_view = [](std::shared_ptr<at::Tensor> &buffer, const std::vector<size_t> &shape,
size_t offset, at::ScalarType dtype) -> at::Tensor {
std::vector<int64_t> shape_int64(shape.begin(), shape.end());
    // If the full buffer is empty because the local rank received no tokens for any expert,
    // data_ptr is nullptr and we must return an empty tensor instead of calling from_blob.
    // If only some experts receive zero tokens, we still use from_blob as much as possible
    // to avoid CPU overhead.
if (buffer->data_ptr<uint8_t>() == nullptr) {
return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype));
}
return at::from_blob(
buffer->data_ptr<uint8_t>() + offset, shape_int64,
[buffer](void *) {}, // deleter holds shared_ptr
at::device(at::kCUDA).dtype(dtype));
};
// Allocate row-wise data
std::vector<at::Tensor> rowwise_data_list, rowwise_scale_list;
std::vector<std::vector<size_t>> rowwise_data_shapes, rowwise_scale_shapes;
if (rowwise_usage) {
// Tensor sizes
for (size_t i = 0; i < num_tensors; ++i) {
rowwise_data_shapes.emplace_back(shape_list[i]);
rowwise_scale_shapes.emplace_back(
quantizer_cpp_list[i]->get_scale_shape(shape_list[i], false));
}
// Offsets in full buffer
size_t buffer_size = 0;
std::vector<size_t> data_offsets, scale_offsets;
for (size_t i = 0; i < num_tensors; ++i) {
buffer_size = roundup(buffer_size, 256); // align to 256B
data_offsets.push_back(buffer_size);
buffer_size += product(rowwise_data_shapes[i]) * fp8_elem_size;
}
for (size_t i = 0; i < num_tensors; ++i) {
buffer_size = roundup(buffer_size, 16); // align to 16B
scale_offsets.push_back(buffer_size);
buffer_size += product(rowwise_scale_shapes[i]) * scale_elem_size;
}
// Allocate full buffer
auto buffer = std::make_shared<at::Tensor>(
at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
// Construct tensor views
for (size_t i = 0; i < num_tensors; ++i) {
rowwise_data_list.emplace_back(
make_torch_view(buffer, rowwise_data_shapes[i], data_offsets[i], torch::kUInt8));
rowwise_scale_list.emplace_back(
make_torch_view(buffer, rowwise_scale_shapes[i], scale_offsets[i], torch::kFloat32));
}
}
// Allocate column-wise data
std::vector<at::Tensor> columnwise_data_list, columnwise_scale_list;
std::vector<std::vector<size_t>> columnwise_data_shapes, columnwise_scale_shapes;
if (columnwise_usage) {
// Tensor sizes
for (size_t i = 0; i < num_tensors; ++i) {
columnwise_data_shapes.emplace_back();
auto &shape = columnwise_data_shapes.back();
shape.push_back(shape_list[i].back());
for (size_t j = 0; j < shape_list[i].size() - 1; ++j) {
shape.push_back(shape_list[i][j]);
}
columnwise_scale_shapes.emplace_back(
quantizer_cpp_list[i]->get_scale_shape(shape_list[i], true));
}
// Offsets in full buffer
size_t buffer_size = 0;
std::vector<size_t> data_offsets, scale_offsets;
for (size_t i = 0; i < num_tensors; ++i) {
buffer_size = roundup(buffer_size, 256); // align to 256B
data_offsets.push_back(buffer_size);
buffer_size += product(columnwise_data_shapes[i]) * fp8_elem_size;
}
for (size_t i = 0; i < num_tensors; ++i) {
buffer_size = roundup(buffer_size, 16); // align to 16B
scale_offsets.push_back(buffer_size);
buffer_size += product(columnwise_scale_shapes[i]) * scale_elem_size;
}
// Allocate full buffer
auto buffer = std::make_shared<at::Tensor>(
at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)));
// Construct tensor views
for (size_t i = 0; i < num_tensors; ++i) {
columnwise_data_list.emplace_back(
make_torch_view(buffer, columnwise_data_shapes[i], data_offsets[i], torch::kUInt8));
columnwise_scale_list.emplace_back(
make_torch_view(buffer, columnwise_scale_shapes[i], scale_offsets[i], torch::kFloat32));
}
}
// Construct FP8 block-wise tensors
py::handle Float8BlockwiseQTensorClass(
reinterpret_cast<PyObject *>(Float8BlockwiseQTensorBasePythonClass));
for (size_t i = 0; i < num_tensors; ++i) {
// Create tensor objects with proper reference counting
py::object rowwise_data = rowwise_usage ? py::cast(rowwise_data_list[i]) : py::none();
py::object rowwise_scale = rowwise_usage ? py::cast(rowwise_scale_list[i]) : py::none();
py::object columnwise_data =
(columnwise_usage ? py::cast(columnwise_data_list[i]) : py::none());
py::object columnwise_scale =
(columnwise_usage ? py::cast(columnwise_scale_list[i]) : py::none());
// Construct Python tensor
tensor_py_list.emplace_back(Float8BlockwiseQTensorClass(
rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, fp8_dtype,
quantizer_py_list[i], is_2D_scaled, Float8BlockScaleTensorFormat::GEMM_READY));
// Construct C++ tensor
tensor_cpp_list.emplace_back(makeTransformerEngineTensor(
rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr,
columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr,
rowwise_usage ? rowwise_data_shapes[i] : std::vector<size_t>{},
columnwise_usage ? columnwise_data_shapes[i] : std::vector<size_t>{}, fp8_dtype, nullptr,
nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr,
columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr,
rowwise_usage ? rowwise_scale_shapes[i] : std::vector<size_t>{},
columnwise_usage ? columnwise_scale_shapes[i] : std::vector<size_t>{}, scaling_mode));
}
return retval;
}
} // namespace
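The bulk allocation above is the core CPU-overhead saving: a single torch.empty backs all per-expert data and scale tensors, and each logical tensor is just a view at an aligned offset, so the deleter only has to keep the shared buffer alive. A small illustrative Python sketch of the same packing idea follows; the helper names are hypothetical, the uint8 element size and 256-byte alignment mirror the code above, and the sketch is not part of this diff.

import math
import torch

def round_up(x, multiple):
    return ((x + multiple - 1) // multiple) * multiple

def bulk_allocate_views(shapes, align=256, device="cuda"):
    # Compute aligned byte offsets for each uint8 tensor, allocate one shared
    # buffer, then hand out views into it instead of calling torch.empty per tensor.
    offsets, total = [], 0
    for shape in shapes:
        total = round_up(total, align)
        offsets.append(total)
        total += math.prod(shape)
    buffer = torch.empty(total, dtype=torch.uint8, device=device)  # single allocation
    return [buffer[off:off + math.prod(shape)].view(shape)
            for off, shape in zip(offsets, shapes)]

# e.g. three per-expert FP8 data buffers carved out of one allocation,
# including an expert that received zero tokens:
# views = bulk_allocate_views([(128, 4096), (0, 4096), (64, 4096)])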
std::vector<py::object> split_quantize(const at::Tensor &tensor,
const std::vector<int> &split_sections,
std::vector<py::handle> quantizer_list) {
init_extension();
// Check number of tensors
const size_t num_splits = split_sections.size();
NVTE_CHECK(quantizer_list.size() == num_splits, "Expected ", num_splits, " quantizers, but got ",
quantizer_list.size());
if (num_splits == 0) {
return {};
}
// Input tensor properties
auto input_py = tensor.contiguous();
uint8_t *input_dptr = reinterpret_cast<uint8_t *>(input_py.data_ptr());
auto input_dtype = GetTransformerEngineDType(input_py.scalar_type());
std::vector<size_t> input_shape;
size_t input_size = 1;
for (const auto &d : input_py.sizes()) {
input_shape.push_back(d);
input_size *= d;
}
NVTE_CHECK(input_shape.size() > 0, "Input tensor has 0 dims");
// Split input tensor along dim 0
std::vector<TensorWrapper> input_list;
std::vector<std::vector<size_t>> split_shapes;
size_t dim0_offset = 0;
const size_t dim0_stride =
input_shape[0] == 0 ? 0 : input_py.element_size() * input_size / input_shape[0];
for (size_t i = 0; i < num_splits; ++i) {
NVTE_CHECK(split_sections[i] >= 0, "Attempted to split tensor with shape=", input_shape,
" along dim 0 with split_sections=", split_sections);
NVTE_CHECK(dim0_offset + split_sections[i] <= input_shape[0],
"Attempted to split tensor with shape=", input_shape,
" along dim 0 with split_sections=", split_sections);
split_shapes.push_back(input_shape);
auto &split_shape = split_shapes.back();
split_shape[0] = split_sections[i];
void *split_dptr = static_cast<void *>(input_dptr + dim0_offset * dim0_stride);
input_list.emplace_back(makeTransformerEngineTensor(split_dptr, split_shape, input_dtype));
dim0_offset += split_sections[i];
}
// Convert quantizers to C++ objects
std::vector<std::unique_ptr<Quantizer>> quantizer_cpp_list;
for (size_t i = 0; i < num_splits; i++) {
quantizer_cpp_list.push_back(convert_quantizer(quantizer_list[i]));
}
// For FP8 block-scaling, we construct output tensors with bulk allocations
bool use_fused_bulk_alloc = true;
for (size_t i = 0; i < quantizer_list.size(); i++) {
if (!detail::IsFloat8BlockwiseQuantizers(quantizer_list[i].ptr())) {
use_fused_bulk_alloc = false;
break;
}
}
// Allocate output tensors
std::vector<TensorWrapper> output_cpp_list;
std::vector<py::object> output_py_list;
if (!use_fused_bulk_alloc) {
// Allocate output tensors individually
for (size_t i = 0; i < num_splits; ++i) {
auto [output_cpp, output_py] =
quantizer_cpp_list[i]->create_tensor(split_shapes[i], input_dtype);
output_cpp_list.emplace_back(std::move(output_cpp));
output_py_list.emplace_back(std::move(output_py));
}
} else {
// FP8 block-scaling: construct output tensors with bulk allocations
std::vector<Float8BlockQuantizer *> blockwise_quantizers;
for (auto &quantizer : quantizer_cpp_list) {
blockwise_quantizers.push_back(static_cast<Float8BlockQuantizer *>(quantizer.get()));
}
std::tie(output_py_list, output_cpp_list) =
bulk_allocate_fp8_blockwise_tensors(split_shapes, quantizer_list, blockwise_quantizers);
}
// Perform multi-tensor quantization
multi_tensor_quantize_impl(input_list, quantizer_list, quantizer_cpp_list, output_cpp_list);
return output_py_list;
}
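Note that split_quantize never materializes the input splits as separate torch.Tensor objects; it just walks offsets along dim 0 of the contiguous input, which also covers experts that receive zero tokens. An illustrative Python equivalent of that split logic (editorial, not part of this diff) is:

import torch

def split_views_dim0(inp, split_sections):
    # Zero-copy views along dim 0, mirroring the offset arithmetic above.
    # torch.split would return similar views, but creating many Python tensor
    # objects per microbatch is exactly the overhead the C++ path avoids.
    assert inp.is_contiguous()
    views, offset = [], 0
    for n in split_sections:
        assert n >= 0 and offset + n <= inp.shape[0]
        views.append(inp[offset:offset + n])
        offset += n
    return views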
template <void (*func)(const NVTETensor, const NVTETensor, NVTETensor, NVTETensor, NVTETensor,
cudaStream_t)>
std::vector<py::object> dbias_dact(const at::Tensor& grad_output, const at::Tensor& act_input,
std::vector<py::object> dbias_dact(const at::Tensor &grad_output, const at::Tensor &act_input,
py::handle quantizer) {
init_extension();
auto my_quantizer = convert_quantizer(quantizer);
......@@ -125,7 +488,7 @@ std::vector<py::object> dbias_dact(const at::Tensor& grad_output, const at::Tens
auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_tensor.dtype());
auto act_input_tensor = makeTransformerEngineTensor(act_input);
const auto& shape = convertShape(grad_tensor.shape());
const auto &shape = convertShape(grad_tensor.shape());
auto [dact_tensor, dact] = my_quantizer->create_tensor(shape, act_input_tensor.dtype());
auto dbias_tensor = makeTransformerEngineTensor(grad_bias);
......@@ -149,29 +512,30 @@ std::vector<py::object> dbias_dact(const at::Tensor& grad_output, const at::Tens
return {py::cast(grad_bias), dact};
}
std::vector<py::object> dbias_dgelu(const at::Tensor& grad_output, const at::Tensor& act_input,
std::vector<py::object> dbias_dgelu(const at::Tensor &grad_output, const at::Tensor &act_input,
py::handle quantizer) {
return dbias_dact<nvte_quantize_dbias_dgelu>(grad_output, act_input, quantizer);
}
std::vector<py::object> dbias_dsilu(const at::Tensor& grad_output, const at::Tensor& act_input,
std::vector<py::object> dbias_dsilu(const at::Tensor &grad_output, const at::Tensor &act_input,
py::handle quantizer) {
return dbias_dact<nvte_quantize_dbias_dsilu>(grad_output, act_input, quantizer);
}
std::vector<py::object> dbias_drelu(const at::Tensor& grad_output, const at::Tensor& act_input,
std::vector<py::object> dbias_drelu(const at::Tensor &grad_output, const at::Tensor &act_input,
py::handle quantizer) {
return dbias_dact<nvte_quantize_dbias_drelu>(grad_output, act_input, quantizer);
}
std::vector<py::object> dbias_dqgelu(const at::Tensor& grad_output, const at::Tensor& act_input,
std::vector<py::object> dbias_dqgelu(const at::Tensor &grad_output, const at::Tensor &act_input,
py::handle quantizer) {
return dbias_dact<nvte_quantize_dbias_dqgelu>(grad_output, act_input, quantizer);
}
std::vector<py::object> dbias_dsrelu(const at::Tensor& grad_output, const at::Tensor& act_input,
std::vector<py::object> dbias_dsrelu(const at::Tensor &grad_output, const at::Tensor &act_input,
py::handle quantizer) {
return dbias_dact<nvte_quantize_dbias_dsrelu>(grad_output, act_input, quantizer);
}
} // namespace transformer_engine::pytorch
} // namespace pytorch
} // namespace transformer_engine
......@@ -12,7 +12,9 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <stdexcept>
#include <memory>
#include <optional>
#include <vector>
#include "../common.h"
#include "../extensions.h"
......@@ -199,10 +201,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("weight"), py::arg("eps"), py::arg("ln_out"), py::arg("quantizer"),
py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma"));
m.def("rmsnorm_bwd", &transformer_engine::pytorch::rmsnorm_bwd, "Backward of RMSNorm");
m.def("fused_multi_quantize", &transformer_engine::pytorch::fused_multi_quantize,
"Fused Multi-tensor Cast + Transpose", py::arg("input_list"), py::arg("output_list"),
py::arg("quantizer_list"), py::arg("otype"));
m.def("multi_tensor_quantize", &transformer_engine::pytorch::multi_tensor_quantize,
"Multi-tensor quantize", py::arg("tensor_list"), py::arg("quantizer_list"));
m.def("split_quantize", &transformer_engine::pytorch::split_quantize,
"Split and multi-tensor quantize", py::arg("tensor"), py::arg("split_sections"),
py::arg("quantizer_list"));
m.def("te_general_grouped_gemm", &transformer_engine::pytorch::te_general_grouped_gemm,
"Grouped GEMM");
m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O",
......
......@@ -4,80 +4,16 @@
* See LICENSE for license information.
************************************************************************/
#include <pybind.h>
#include <optional>
#include <vector>
#include "../extensions.h"
#include "pybind.h"
namespace transformer_engine::pytorch {
std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list,
std::optional<std::vector<py::object>> output_list,
std::vector<py::handle> quantizer_list, DType otype) {
init_extension();
std::vector<NVTETensor> nvte_tensor_input_list;
std::vector<NVTETensor> nvte_tensor_output_list;
std::vector<py::object> py_output_objects_list;
std::vector<TensorWrapper> tensor_wrappers;
if (output_list.has_value()) {
py_output_objects_list = output_list.value();
}
// Choose implementation
// Note: Currently only have fused kernel for FP8 cast-transpose
bool with_fused_kernel = true;
// create TE tensors from input
for (size_t i = 0; i < input_list.size(); i++) {
auto input_tensor = makeTransformerEngineTensor(input_list[i]);
const NVTEShape input_shape = input_tensor.shape();
TensorWrapper output_tensor;
if (!detail::IsFloat8Quantizers(quantizer_list[i].ptr())) {
with_fused_kernel = false;
}
if (output_list == std::nullopt) {
std::unique_ptr<Quantizer> quantizer = convert_quantizer(quantizer_list[i]);
std::vector<size_t> output_shape(input_shape.data, input_shape.data + input_shape.ndim);
py::object o;
std::tie(output_tensor, o) = quantizer->create_tensor(output_shape, otype);
py_output_objects_list.push_back(o);
} else {
output_tensor = makeTransformerEngineTensor((*output_list)[i], quantizer_list[i]);
}
if (input_tensor.numel() == 0) continue;
nvte_tensor_output_list.emplace_back(output_tensor.data());
nvte_tensor_input_list.emplace_back(input_tensor.data());
tensor_wrappers.emplace_back(std::move(input_tensor));
tensor_wrappers.emplace_back(std::move(output_tensor));
}
// Check tensor lists
NVTE_CHECK(nvte_tensor_output_list.size() == nvte_tensor_input_list.size(),
"Number of input and output tensors must match");
for (size_t i = 0; i < nvte_tensor_output_list.size(); i++) {
if (nvte_tensor_columnwise_data(nvte_tensor_output_list[i]) == nullptr) {
with_fused_kernel = false;
break;
}
}
// Launch TE kernel
if (with_fused_kernel) {
NVTE_SCOPED_GIL_RELEASE({
nvte_multi_cast_transpose(nvte_tensor_input_list.size(), nvte_tensor_input_list.data(),
nvte_tensor_output_list.data(), at::cuda::getCurrentCUDAStream());
});
} else {
for (size_t i = 0; i < py_output_objects_list.size(); i++) {
quantize(input_list[i], quantizer_list[i], py_output_objects_list[i], std::nullopt);
}
}
return py_output_objects_list;
}
namespace transformer_engine {
namespace pytorch {
at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional<at::Tensor> output) {
init_extension();
......@@ -108,4 +44,5 @@ at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional<at::Tensor
return out;
}
} // namespace transformer_engine::pytorch
} // namespace pytorch
} // namespace transformer_engine
......@@ -283,10 +283,8 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> rowwise_data) const {
using namespace pybind11::literals;
std::vector<int64_t> torch_shape;
size_t numel = 1;
for (auto s : shape) {
torch_shape.emplace_back(static_cast<int64_t>(s));
numel *= s;
}
TensorWrapper tensor(this->get_scaling_mode());
......@@ -296,10 +294,6 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
opts = opts.dtype(torch::kUInt8).device(torch::kCUDA);
scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA);
size_t k_dim = torch_shape.size() == 0 ? 1u : torch_shape.back();
size_t m_dim = numel / k_dim;
constexpr size_t kBlockLen = 128;
Float8BlockScaleTensorFormat data_format =
(all_gather_usage ? Float8BlockScaleTensorFormat::COMPACT
: Float8BlockScaleTensorFormat::GEMM_READY);
......@@ -310,30 +304,9 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
} else {
data_rowwise = at::empty(torch_shape, opts);
}
size_t sinv0 = 0;
size_t sinv1 = 0;
if (block_scaling_dim == 2) {
// 2D scaling is always GEMM_READY for now
NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY,
"2D scaling is always GEMM_READY for now.");
sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
sinv1 = roundup((k_dim + kBlockLen - 1) / kBlockLen, 4);
} else if (block_scaling_dim == 1) {
// 1D scaling can be GEMM_READY or COMPACT
bool rowwise_compact = data_format == Float8BlockScaleTensorFormat::COMPACT;
    // the default rowwise scaling factor shape already transposes the scaling factor, so it's GEMM_READY
    sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
    sinv1 = rowwise_compact ? m_dim : roundup(m_dim, 4);
    // if the rowwise format is compact, the scaling factor is not transposed
if (rowwise_compact) {
std::swap(sinv0, sinv1);
}
} else {
NVTE_ERROR(
"Unsupported block_scaling_dim in create_tensor rowwise. "
"Expected 1 or 2. Got ",
block_scaling_dim);
}
auto scale_shape = get_scale_shape(shape, false);
size_t sinv0 = scale_shape[0];
size_t sinv1 = scale_shape[1];
scale_inv_rowwise =
at::empty({static_cast<int64_t>(sinv0), static_cast<int64_t>(sinv1)}, scale_opts);
tensor.set_rowwise_data(data_rowwise.data_ptr(), this->dtype, shape);
......@@ -364,27 +337,9 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
columnwise_shape = shape;
}
}
size_t sinv0 = 0;
size_t sinv1 = 0;
if (block_scaling_dim == 2) {
// 2D scaling is always GEMM_READY for now
NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY,
"2D scaling is always GEMM_READY for now.");
sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
sinv1 = roundup((m_dim + kBlockLen - 1) / kBlockLen, 4);
} else if (block_scaling_dim == 1) {
bool columnwise_compact = data_format == Float8BlockScaleTensorFormat::COMPACT;
sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
sinv1 = columnwise_compact ? k_dim : roundup(k_dim, 4);
// GEMM READY case: scaling factor is [sinv0, sinv1], already transposed here for CuBLAS
// for COMPACT case, since we apply 128x1 scaling here without transposing columnwise data, scaling factor is also [sinv0, sinv1]
// so no need to swap sinv0 and sinv1 here
} else {
NVTE_ERROR(
"Unsupported block_scaling_dim in create_tensor columnwise. "
"Expected 1 or 2. Got ",
block_scaling_dim);
}
auto scale_shape = get_scale_shape(shape, true);
size_t sinv0 = scale_shape[0];
size_t sinv1 = scale_shape[1];
data_colwise = at::empty(torch_columnwise_shape, opts);
scale_inv_colwise =
at::empty({static_cast<int64_t>(sinv0), static_cast<int64_t>(sinv1)}, scale_opts);
......@@ -418,6 +373,81 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
return {std::move(tensor), std::move(ret)};
}
std::vector<size_t> Float8BlockQuantizer::get_scale_shape(const std::vector<size_t>& shape,
bool columnwise) const {
size_t numel = 1;
for (auto s : shape) {
numel *= s;
}
size_t k_dim = shape.size() == 0 ? 1u : shape.back();
size_t m_dim = numel / k_dim;
constexpr size_t kBlockLen = 128;
Float8BlockScaleTensorFormat data_format =
(all_gather_usage ? Float8BlockScaleTensorFormat::COMPACT
: Float8BlockScaleTensorFormat::GEMM_READY);
std::vector<size_t> scale_shape;
bool rowwise_usage = !columnwise;
if (rowwise_usage) {
// rowwise scaling factor shape
size_t sinv0 = 0;
size_t sinv1 = 0;
if (block_scaling_dim == 2) {
// 2D scaling is always GEMM_READY for now
NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY,
"2D scaling is always GEMM_READY for now.");
sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
sinv1 = roundup((k_dim + kBlockLen - 1) / kBlockLen, 4);
} else if (block_scaling_dim == 1) {
// 1D scaling can be GEMM_READY or COMPACT
bool rowwise_compact = data_format == Float8BlockScaleTensorFormat::COMPACT;
      // the default rowwise scaling factor shape already transposes the scaling factor, so it's GEMM_READY
      sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
      sinv1 = rowwise_compact ? m_dim : roundup(m_dim, 4);
      // if the rowwise format is compact, the scaling factor is not transposed
if (rowwise_compact) {
std::swap(sinv0, sinv1);
}
} else {
      NVTE_CHECK(false,
                 "Unsupported block_scaling_dim in get_scale_shape rowwise. "
                 "Expected 1 or 2. Got ",
                 block_scaling_dim);
}
scale_shape = {sinv0, sinv1};
} else {
// columnwise scaling factor shape
size_t sinv0 = 0;
size_t sinv1 = 0;
if (block_scaling_dim == 2) {
// 2D scaling is always GEMM_READY for now
NVTE_CHECK(data_format == Float8BlockScaleTensorFormat::GEMM_READY,
"2D scaling is always GEMM_READY for now.");
sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
sinv1 = roundup((m_dim + kBlockLen - 1) / kBlockLen, 4);
} else if (block_scaling_dim == 1) {
// 1D scaling can be GEMM_READY or COMPACT
bool columnwise_compact = data_format == Float8BlockScaleTensorFormat::COMPACT;
sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
sinv1 = columnwise_compact ? k_dim : roundup(k_dim, 4);
// GEMM READY case: scaling factor is [sinv0, sinv1], already transposed here for CuBLAS
// for COMPACT case, since we apply 128x1 scaling here without transposing columnwise data, scaling factor is also [sinv0, sinv1]
// so no need to swap sinv0 and sinv1 here
} else {
      NVTE_CHECK(false,
                 "Unsupported block_scaling_dim in get_scale_shape columnwise. "
                 "Expected 1 or 2. Got ",
                 block_scaling_dim);
}
scale_shape = {sinv0, sinv1};
}
return scale_shape;
}
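To make the scale-shape arithmetic above concrete, here is an illustrative calculation (editorial, not part of this diff) for a tensor with (m, k) = (300, 4096), kBlockLen = 128, and the GEMM_READY layout:

def ceil_div(a, b):
    return (a + b - 1) // b

def round_up(a, b):
    return ceil_div(a, b) * b

m, k, block = 300, 4096, 128
rowwise_1d = (ceil_div(k, block), round_up(m, 4))                       # (32, 300)
columnwise_1d = (ceil_div(m, block), round_up(k, 4))                    # (3, 4096)
rowwise_2d = (ceil_div(m, block), round_up(ceil_div(k, block), 4))      # (3, 32)
columnwise_2d = (ceil_div(k, block), round_up(ceil_div(m, block), 4))   # (32, 4)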
MXFP8Quantizer::MXFP8Quantizer(const py::handle& quantizer) : Quantizer(quantizer) {
this->dtype = quantizer.attr("dtype").cast<DType>();
}
......
......@@ -24,7 +24,6 @@ from ..fp8 import FP8GlobalStateManager
from ..utils import (
divide,
cast_if_needed,
assert_dim_for_fp8_exec,
clear_tensor_data,
init_method_constant,
requires_grad,
......@@ -38,7 +37,7 @@ from ..distributed import (
from ..cpp_extensions import (
general_grouped_gemm,
)
from ..constants import GemmParallelModes, dist_group_type, TE_DType
from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ..cpu_offload import is_cpu_offload_enabled
......@@ -87,20 +86,9 @@ class _GroupedLinear(torch.autograd.Function):
weights = weights_and_biases[:num_gemms]
biases = weights_and_biases[num_gemms:]
device = inp.device
# Make sure input dimensions are compatible
in_features = weights[0].shape[-1]
assert inp.shape[-1] == in_features, "GEMM not possible"
inputmats = torch.split(inp.view(-1, in_features), m_splits)
if fp8:
assert_dim_for_fp8_exec(*inputmats, *weights)
# Cast input to expected dtype
inputmats_no_fp8 = [cast_if_needed(mat, activation_dtype) for mat in inputmats]
inputmats = []
weight_requires_grad = weights[0].requires_grad
# Configure quantizers
if input_quantizers[0] is not None:
for input_quantizer in input_quantizers:
input_quantizer.set_usage(
......@@ -120,17 +108,25 @@ class _GroupedLinear(torch.autograd.Function):
for output_quantizer in output_quantizers:
output_quantizer.set_usage(rowwise=True, columnwise=False)
fprop_gemm_use_split_accumulator = _2X_ACC_FPROP
if fp8:
recipe = FP8GlobalStateManager.get_fp8_recipe()
if hasattr(recipe, "fp8_gemm_fprop"):
fprop_gemm_use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator
inputmats = tex.fused_multi_quantize(
inputmats_no_fp8, None, input_quantizers, TE_DType[activation_dtype]
# Initialize input tensors
in_features = weights[0].size(-1)
if inp.size(-1) != in_features:
raise ValueError(
f"Input tensor (shape={tuple(inp.size())}) is not compatible with "
f"weight tensor (shape={tuple(weights[0].size())})"
)
weights_fp8 = []
bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype
inp_view = inp.reshape(-1, in_features)
inputmats: list
if fp8:
inputmats = tex.split_quantize(inp_view, m_splits, input_quantizers)
else:
inputmats = torch.split(cast_if_needed(inp_view, activation_dtype), m_splits)
# Initialize weights
weights_fp8: list
if fp8:
# FP8 cast to workspace buffer
weights_fp8 = []
update_workspace = is_first_microbatch is None or is_first_microbatch
for i in range(num_gemms):
weight_fp8 = module.get_weight_workspace(
......@@ -143,18 +139,29 @@ class _GroupedLinear(torch.autograd.Function):
weights_fp8.append(weight_fp8)
else:
inputmats = inputmats_no_fp8
bias_dtype = activation_dtype
weights_fp8 = [cast_if_needed(weight, activation_dtype) for weight in weights]
# Initialize biases
bias_dtype = activation_dtype
if fp8 and activation_dtype == torch.float32:
bias_dtype = torch.bfloat16 # FP8 GEMM only supports BF16/FP16 bias
biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases
# Initialize output tensor
out = torch.empty(
[sum(m_splits), weights_fp8[0].size(0)],
dtype=activation_dtype,
device=device,
)
# Choose whether to use split accumulator
use_split_accumulator = _2X_ACC_FPROP
if fp8:
recipe = FP8GlobalStateManager.get_fp8_recipe()
if hasattr(recipe, "fp8_gemm_fprop"):
use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator
# Perform GEMM
_ = general_grouped_gemm(
weights_fp8,
inputmats,
......@@ -165,7 +172,7 @@ class _GroupedLinear(torch.autograd.Function):
m_splits=m_splits,
bias=biases,
use_bias=use_bias,
use_split_accumulator=fprop_gemm_use_split_accumulator,
use_split_accumulator=use_split_accumulator,
)
if fp8_calibration:
......@@ -247,36 +254,44 @@ class _GroupedLinear(torch.autograd.Function):
w.main_grad = main_grads[i]
weights[i] = w
# preprocess grad_output
grad_output = grad_output.contiguous()
grad_output_mats = torch.split(
grad_output.view(-1, grad_output.shape[-1]), ctx.m_splits
)
# Preprocess grad output
grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1])
grad_output = [None] * ctx.num_gemms
grad_biases = [None] * ctx.num_gemms
if ctx.fp8:
if ctx.use_bias:
# unfuse bgrad for now until cast_transpose + dgrad calculation is ready
# for Float8BlockQuantizer.
if ctx.fp8_recipe.float8_block_scaling():
grad_output_mats = torch.split(grad_output_view, ctx.m_splits)
recipe = ctx.fp8_recipe
if recipe.delayed() or recipe.float8_current_scaling() or recipe.mxfp8():
# Fused bias grad + quantize kernel
for i in range(ctx.num_gemms):
grad_biases[i] = grad_output_mats[i].sum(dim=0)
grad_output[i] = ctx.grad_output_quantizers[i](grad_output_mats[i])
grad_biases[i], grad_output[i] = tex.bgrad_quantize(
grad_output_mats[i],
ctx.grad_output_quantizers[i],
)
else:
# Unfused bias grad and multi-tensor quantize
for i in range(ctx.num_gemms):
grad_biases[i], grad_output[i] = tex.bgrad_quantize(
grad_output_mats[i], ctx.grad_output_quantizers[i]
grad_biases[i] = grad_output_mats[i].sum(dim=0)
grad_output = tex.split_quantize(
grad_output_view,
ctx.m_splits,
ctx.grad_output_quantizers,
)
else:
grad_output = tex.fused_multi_quantize(
grad_output_mats,
None,
# Multi-tensor quantize
grad_output = tex.split_quantize(
grad_output_view,
ctx.m_splits,
ctx.grad_output_quantizers,
TE_DType[ctx.activation_dtype],
)
else:
grad_output = grad_output_mats
# Only split grad output. Grad bias is fused with
# wgrad GEMM.
grad_output = torch.split(
cast_if_needed(grad_output_view, ctx.activation_dtype),
ctx.m_splits,
)
if ctx.is_first_microbatch is not None:
accumulate_wgrad_into_param_main_grad = (
......
......@@ -42,7 +42,6 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
def __new__(
cls,
*args,
rowwise_data: Optional[torch.Tensor],
rowwise_scale_inv: Optional[torch.Tensor],
columnwise_data: Optional[torch.Tensor],
......@@ -50,7 +49,8 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
fp8_dtype: TE_DType,
quantizer: Quantizer,
is_2D_scaled: bool,
data_format: Float8BlockScaleTensorFormat = Float8BlockScaleTensorFormat.GEMM_READY,
data_format: Float8BlockScaleTensorFormat,
*args,
**kwargs,
):
instance = super().__new__(cls, *args, **kwargs)
......
......@@ -10,6 +10,7 @@ import math
import torch
import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
from transformer_engine_torch import Float8BlockScaleTensorFormat
from transformer_engine.common.recipe import Float8BlockScaling, Recipe
from ._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
......@@ -294,6 +295,37 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
holds configuration about quantization and dequantization modes.
"""
# NOTE: We reorder the *args so that we can instantiate a Float8BlockwiseQTensorBase with positional args,
# which significantly reduces the Pybind11 overhead when calling the constructor from C++.
def __new__(
cls,
*args,
rowwise_data: Optional[torch.Tensor],
rowwise_scale_inv: Optional[torch.Tensor],
columnwise_data: Optional[torch.Tensor],
columnwise_scale_inv: Optional[torch.Tensor],
fp8_dtype: TE_DType,
quantizer: Quantizer,
is_2D_scaled: bool,
data_format: tex.Float8BlockScaleTensorFormat = Float8BlockScaleTensorFormat.GEMM_READY,
**kwargs,
):
instance = super().__new__(
cls,
rowwise_data,
rowwise_scale_inv,
columnwise_data,
columnwise_scale_inv,
fp8_dtype,
quantizer,
is_2D_scaled,
data_format,
*args,
**kwargs,
)
return instance
def __repr__(self, *, tensor_contents=None):
return (
f"Float8BlockwiseQTensor(fp8_dtype={self._fp8_dtype},"
......
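The *args reordering noted above matters because constructing the tensor from C++ with positional arguments avoids building a Python keyword dictionary on every call. A rough pure-Python analogue of that cost difference (illustrative only; it does not measure the pybind11 path itself) is:

import timeit

class Dummy:
    def __init__(self, a, b, c, d, e, f, g, h):
        self.values = (a, b, c, d, e, f, g, h)

positional = timeit.timeit(lambda: Dummy(1, 2, 3, 4, 5, 6, 7, 8), number=200_000)
keyword = timeit.timeit(
    lambda: Dummy(a=1, b=2, c=3, d=4, e=5, f=6, g=7, h=8), number=200_000
)
print(f"positional: {positional:.3f}s  keyword: {keyword:.3f}s")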