Unverified Commit a169e9e7 authored by Tim Moon, committed by GitHub

[PyTorch] Disable fused dbias-quantize kernel for unsupported recipes (#2007)



* Unfused impl for dbias-quantize
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Unfused impl for dact-dbias-quantize
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Disable fused bgrad-quantize for unsupported recipes
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove unfused dbias-quantize impls

Not supported in the core lib.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Support unfused impls in tex functions
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Tweaks
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Remove unused imports
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 26b4b71a
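
In short, the extension's bgrad-quantize and dact-dbias-quantize entry points now fall back to an unfused implementation when the quantization recipe has no fused dbias-quantize kernel (anything other than FP8 delayed scaling or MXFP8), rather than the Python-side fuser refusing to fuse the ops. Below is a minimal sketch of that fallback in plain PyTorch; bgrad_quantize_unfused and quantize_fn are hypothetical stand-ins for illustration, not the extension's actual API, which does this work in C++.

import torch

# Illustrative sketch of the unfused fallback path (hypothetical helper, not
# the real transformer_engine_torch.bgrad_quantize): compute the bias gradient
# as a plain reduction over the leading dimensions, then quantize the grad
# input in a separate step instead of using the fused dbias-quantize kernel.
def bgrad_quantize_unfused(grad_output: torch.Tensor, quantize_fn=None):
    bias_size = grad_output.size(-1)
    grad_bias = grad_output.reshape(-1, bias_size).sum(dim=0)
    if quantize_fn is None:
        # No quantizer: only the bias gradient is needed
        return grad_bias, grad_output
    return grad_bias, quantize_fn(grad_output)

# Usage with a stand-in "quantizer"; real recipes quantize to FP8/MXFP8 formats.
dy = torch.randn(32, 64)
db, dq = bgrad_quantize_unfused(dy, quantize_fn=lambda t: t.to(torch.bfloat16))
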
@@ -2078,7 +2078,7 @@ class TestFusedOps:
# Check that backward operations have been fused
backward_ops = model._module_groups[0]._backward_ops
if with_quantization and quantization in ["fp8_delayed_scaling", "mxfp8"]:
if with_quantization:
assert len(backward_ops) == 2
assert isinstance(backward_ops[0][0], BackwardActivationBias)
assert isinstance(backward_ops[1][0], te_ops.Quantize)
@@ -2093,6 +2093,7 @@ class TestFusedOps:
if with_quantization:
tols = dtype_tols(tex.DType.kFloat8E4M3)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
db_test = model[1].bias.grad.to(dtype=torch.float64, device="cpu")
@@ -4,80 +4,223 @@
* See LICENSE for license information.
************************************************************************/
#include <ATen/ATen.h>
#include <pybind11/pybind11.h>
#include <utility>
#include <vector>
#include "common.h"
#include "extensions.h"
#include "pybind.h"
#include "transformer_engine/cast.h"
#include "transformer_engine/transformer_engine.h"
namespace transformer_engine::pytorch {
namespace transformer_engine {
namespace pytorch {
std::vector<py::object> bgrad_quantize(const at::Tensor &grad_output, py::handle quantizer) {
using namespace transformer_engine::pytorch::detail;
init_extension();
std::vector<py::object> bgrad_quantize(const at::Tensor& input, py::handle py_quantizer) {
auto quantizer = convert_quantizer(py_quantizer);
// Grad output tensor
auto grad_output_torch = grad_output.contiguous();
const TensorWrapper &grad_output_nvte = makeTransformerEngineTensor(grad_output_torch);
const auto shape = getTensorShape(grad_output_torch);
auto grad_output_dtype = GetTransformerEngineDType(grad_output_torch.scalar_type());
auto input_tensor = makeTransformerEngineTensor(input);
// Construct grad bias tensor
const int64_t bias_size = static_cast<int64_t>(shape.back());
auto grad_bias_torch = allocateTorchTensor(bias_size, grad_output_dtype);
auto grad_bias_nvte = makeTransformerEngineTensor(grad_bias_torch);
// Unquantized impl only requires computing grad bias
if (quantizer.is_none()) {
if (product(shape) == 0) {
grad_bias_torch.zero_();
} else {
at::sum_out(grad_bias_torch, grad_output_torch.reshape({-1, bias_size}), {0});
}
return {py::cast(std::move(grad_bias_torch)), py::cast(std::move(grad_output_torch))};
}
auto dbias = allocateTorchTensor(input.size(-1), input_tensor.dtype());
// Construct grad input tensor
auto quantizer_cpp = convert_quantizer(quantizer);
auto [grad_input_nvte, grad_input_py] = quantizer_cpp->create_tensor(shape, grad_output_dtype);
std::vector<size_t> output_shape;
for (auto s : input.sizes()) {
output_shape.emplace_back(static_cast<size_t>(s));
// Trivial impl if tensors are empty
if (product(shape) == 0) {
grad_bias_torch.zero_();
return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)};
}
auto [out_tensor, out] = quantizer->create_tensor(output_shape, input_tensor.dtype());
// Unfused impl if quantizer is not supported
const bool with_fused_dbias_quantize_kernel =
detail::IsFloat8Quantizers(quantizer.ptr()) || detail::IsMXFP8Quantizers(quantizer.ptr());
if (!with_fused_dbias_quantize_kernel) {
at::sum_out(grad_bias_torch, grad_output_torch.reshape({-1, bias_size}), {0});
quantizer_cpp->quantize(grad_output_nvte, grad_input_nvte);
return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)};
}
// Query workspace size
TensorWrapper workspace_nvte;
at::Tensor workspace_torch;
auto stream = at::cuda::getCurrentCUDAStream();
NVTE_SCOPED_GIL_RELEASE({
nvte_quantize_dbias(grad_output_nvte.data(), grad_input_nvte.data(), grad_bias_nvte.data(),
workspace_nvte.data(), stream);
});
// Allocate workspace
if (workspace_nvte.ndim() > 0 && workspace_nvte.numel() > 0) {
workspace_torch = allocateSpace(workspace_nvte.shape(), workspace_nvte.dtype());
workspace_nvte = makeTransformerEngineTensor(workspace_torch.data_ptr(), workspace_nvte.shape(),
workspace_nvte.dtype());
}
// Launch fused kernel
NVTE_SCOPED_GIL_RELEASE({
nvte_quantize_dbias(grad_output_nvte.data(), grad_input_nvte.data(), grad_bias_nvte.data(),
workspace_nvte.data(), stream);
});
return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)};
}
namespace {
std::vector<py::object> dact_dbias(
void (*dact_dbias_func)(const NVTETensor, const NVTETensor, NVTETensor, NVTETensor, NVTETensor,
cudaStream_t),
void (*dact_func)(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t),
at::Tensor grad_output_torch, at::Tensor act_input_torch, py::handle quantizer_py) {
using namespace transformer_engine::pytorch::detail;
init_extension();
// Grad output and activation input tensors
grad_output_torch = grad_output_torch.contiguous();
const TensorWrapper &grad_output_nvte = makeTransformerEngineTensor(grad_output_torch);
const auto output_shape = getTensorShape(grad_output_torch);
auto grad_output_dtype = GetTransformerEngineDType(grad_output_torch.scalar_type());
act_input_torch = act_input_torch.contiguous();
const TensorWrapper &act_input_nvte = makeTransformerEngineTensor(act_input_torch);
const auto input_shape = getTensorShape(act_input_torch);
// Construct tensors
auto quantizer_cpp = convert_quantizer(quantizer_py);
auto [grad_input_nvte, grad_input_py] =
quantizer_cpp->create_tensor(input_shape, grad_output_dtype);
const int64_t bias_size = static_cast<int64_t>(input_shape.back());
auto grad_bias_torch = allocateTorchTensor(bias_size, grad_output_dtype);
auto grad_bias_nvte = makeTransformerEngineTensor(grad_bias_torch);
// Return immediately if tensors are empty
if (product(output_shape) == 0) {
return {py::cast(dbias.zero_()), out};
grad_bias_torch.zero_();
return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)};
}
// Choose implementation
enum class Impl { UNFUSED, FUSED_DACT_DBIAS_QUANTIZE, FUSED_DACT_AMAX };
Impl impl = Impl::UNFUSED;
if (detail::IsFloat8Quantizers(quantizer_py.ptr()) ||
detail::IsMXFP8Quantizers(quantizer_py.ptr())) {
impl = Impl::FUSED_DACT_DBIAS_QUANTIZE;
} else if (detail::IsFloat8CurrentScalingQuantizers(quantizer_py.ptr())) {
impl = Impl::FUSED_DACT_AMAX;
}
auto dbias_tensor = makeTransformerEngineTensor(dbias);
// Query workspace size and allocate workspace
transformer_engine::TensorWrapper workspace;
// Perform compute
auto stream = at::cuda::getCurrentCUDAStream();
switch (impl) {
case Impl::UNFUSED:
// Unfused dact, dbias, quantize
{
auto [temp_nvte, temp_py] =
NoneQuantizer(py::none()).create_tensor(input_shape, grad_output_dtype);
NVTE_SCOPED_GIL_RELEASE({
dact_func(grad_output_nvte.data(), act_input_nvte.data(), temp_nvte.data(), stream);
});
const auto temp_torch = temp_py.cast<at::Tensor>();
at::sum_out(grad_bias_torch, temp_torch.reshape({-1, bias_size}), {0});
quantizer_cpp->quantize(temp_nvte, grad_input_nvte);
break;
}
case Impl::FUSED_DACT_DBIAS_QUANTIZE:
// Fused dact-dbias-quantize kernel
{
// Query workspace size
TensorWrapper workspace_nvte;
NVTE_SCOPED_GIL_RELEASE({
nvte_quantize_dbias(input_tensor.data(), out_tensor.data(), dbias_tensor.data(),
workspace.data(), at::cuda::getCurrentCUDAStream());
dact_dbias_func(grad_output_nvte.data(), act_input_nvte.data(), grad_input_nvte.data(),
grad_bias_nvte.data(), workspace_nvte.data(), stream);
});
void* workspace_data_ptr = nullptr;
if (workspace.shape().ndim > 0) {
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
workspace_data_ptr = workspace_data.data_ptr();
// Allocate workspace
at::Tensor workspace_torch;
if (workspace_nvte.ndim() > 0 && workspace_nvte.numel() > 0) {
workspace_torch = allocateSpace(workspace_nvte.shape(), workspace_nvte.dtype());
workspace_nvte = makeTransformerEngineTensor(
workspace_torch.data_ptr(), workspace_nvte.shape(), workspace_nvte.dtype());
}
workspace = makeTransformerEngineTensor(workspace_data_ptr, workspace.shape(), workspace.dtype());
// Launch kernel
if (detail::IsFloat8CurrentScalingQuantizers(py_quantizer.ptr())) {
// my_quantizer here has to be a Float8CurrentScalingQuantizer
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer*>(quantizer.get());
NVTE_SCOPED_GIL_RELEASE({
nvte_compute_amax(input_tensor.data(), out_tensor.data(), at::cuda::getCurrentCUDAStream());
dact_dbias_func(grad_output_nvte.data(), act_input_nvte.data(), grad_input_nvte.data(),
grad_bias_nvte.data(), workspace_nvte.data(), stream);
});
// check if we need to do amax reduction (depending on model parallel configs)
if (my_quantizer_cs->with_amax_reduction) {
c10::intrusive_ptr<dist_group_type> process_group_ptr = my_quantizer_cs->amax_reduction_group;
// construct torch tensor from NVTEBasicTensor without reallocating memory
at::Tensor& amax_tensor_torch = my_quantizer_cs->amax;
std::vector<at::Tensor> tensors = {amax_tensor_torch};
// allreduce amax tensor
c10d::AllreduceOptions allreduce_opts;
allreduce_opts.reduceOp = c10d::ReduceOp::MAX;
process_group_ptr->allreduce(tensors, allreduce_opts)->wait();
break;
}
QuantizationConfigWrapper quant_config;
quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
case Impl::FUSED_DACT_AMAX:
// Fused dact-amax kernel, unfused dbias and quantize
{
auto *quantizer_cpp_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(quantizer_cpp.get());
NVTE_CHECK(quantizer_cpp_cs != nullptr,
"Invalid quantizer for fused dact-amax kernel impl");
auto [temp_nvte, temp_py] =
quantizer_cpp_cs->create_hp_tensor_with_amax(input_shape, grad_output_dtype);
NVTE_SCOPED_GIL_RELEASE({
nvte_compute_scale_from_amax(out_tensor.data(), quant_config,
at::cuda::getCurrentCUDAStream());
dact_func(grad_output_nvte.data(), act_input_nvte.data(), temp_nvte.data(), stream);
});
// set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
out_tensor.set_amax(nullptr, DType::kFloat32, out_tensor.defaultShape);
const auto temp_torch = temp_py.cast<at::Tensor>();
at::sum_out(grad_bias_torch, temp_torch.reshape({-1, bias_size}), {0});
quantizer_cpp_cs->quantize_with_amax(temp_nvte, grad_input_nvte);
break;
}
NVTE_SCOPED_GIL_RELEASE({
nvte_quantize_dbias(input_tensor.data(), out_tensor.data(), dbias_tensor.data(),
workspace.data(), at::cuda::getCurrentCUDAStream());
});
default:
NVTE_ERROR("Invalid implementation");
}
return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)};
}
} // namespace
std::vector<py::object> dbias_dgelu(const at::Tensor &grad_output, const at::Tensor &act_input,
py::handle quantizer) {
return dact_dbias(nvte_quantize_dbias_dgelu, nvte_dgelu, grad_output, act_input, quantizer);
}
std::vector<py::object> dbias_dsilu(const at::Tensor &grad_output, const at::Tensor &act_input,
py::handle quantizer) {
return dact_dbias(nvte_quantize_dbias_dsilu, nvte_dsilu, grad_output, act_input, quantizer);
}
std::vector<py::object> dbias_drelu(const at::Tensor &grad_output, const at::Tensor &act_input,
py::handle quantizer) {
return dact_dbias(nvte_quantize_dbias_drelu, nvte_drelu, grad_output, act_input, quantizer);
}
std::vector<py::object> dbias_dqgelu(const at::Tensor &grad_output, const at::Tensor &act_input,
py::handle quantizer) {
return dact_dbias(nvte_quantize_dbias_dqgelu, nvte_dqgelu, grad_output, act_input, quantizer);
}
return {py::cast(dbias), out};
std::vector<py::object> dbias_dsrelu(const at::Tensor &grad_output, const at::Tensor &act_input,
py::handle quantizer) {
return dact_dbias(nvte_quantize_dbias_dsrelu, nvte_dsrelu, grad_output, act_input, quantizer);
}
} // namespace transformer_engine::pytorch
} // namespace pytorch
} // namespace transformer_engine
@@ -587,66 +587,5 @@ std::vector<py::object> split_quantize(const at::Tensor &tensor,
return output_py_list;
}
template <void (*func)(const NVTETensor, const NVTETensor, NVTETensor, NVTETensor, NVTETensor,
cudaStream_t)>
std::vector<py::object> dbias_dact(const at::Tensor &grad_output, const at::Tensor &act_input,
py::handle quantizer) {
init_extension();
auto my_quantizer = convert_quantizer(quantizer);
auto grad_tensor = makeTransformerEngineTensor(grad_output);
auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_tensor.dtype());
auto act_input_tensor = makeTransformerEngineTensor(act_input);
const auto &shape = convertShape(grad_tensor.shape());
auto [dact_tensor, dact] = my_quantizer->create_tensor(shape, act_input_tensor.dtype());
auto dbias_tensor = makeTransformerEngineTensor(grad_bias);
// Query workspace size and allocate workspace
transformer_engine::TensorWrapper workspace;
NVTE_SCOPED_GIL_RELEASE({
func(grad_tensor.data(), act_input_tensor.data(), dact_tensor.data(), dbias_tensor.data(),
workspace.data(), at::cuda::getCurrentCUDAStream());
});
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
workspace =
makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());
// Launch kernel
NVTE_SCOPED_GIL_RELEASE({
func(grad_tensor.data(), act_input_tensor.data(), dact_tensor.data(), dbias_tensor.data(),
workspace.data(), at::cuda::getCurrentCUDAStream());
});
return {py::cast(grad_bias), dact};
}
std::vector<py::object> dbias_dgelu(const at::Tensor &grad_output, const at::Tensor &act_input,
py::handle quantizer) {
return dbias_dact<nvte_quantize_dbias_dgelu>(grad_output, act_input, quantizer);
}
std::vector<py::object> dbias_dsilu(const at::Tensor &grad_output, const at::Tensor &act_input,
py::handle quantizer) {
return dbias_dact<nvte_quantize_dbias_dsilu>(grad_output, act_input, quantizer);
}
std::vector<py::object> dbias_drelu(const at::Tensor &grad_output, const at::Tensor &act_input,
py::handle quantizer) {
return dbias_dact<nvte_quantize_dbias_drelu>(grad_output, act_input, quantizer);
}
std::vector<py::object> dbias_dqgelu(const at::Tensor &grad_output, const at::Tensor &act_input,
py::handle quantizer) {
return dbias_dact<nvte_quantize_dbias_dqgelu>(grad_output, act_input, quantizer);
}
std::vector<py::object> dbias_dsrelu(const at::Tensor &grad_output, const at::Tensor &act_input,
py::handle quantizer) {
return dbias_dact<nvte_quantize_dbias_dsrelu>(grad_output, act_input, quantizer);
}
} // namespace pytorch
} // namespace transformer_engine
@@ -10,14 +10,8 @@ from typing import Optional
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.ops.op import (
BasicOperation,
OperationContext,
)
from ...utils import (
canonicalize_device,
canonicalize_dtype,
)
from ..op import BasicOperation, OperationContext
from ...utils import canonicalize_device, canonicalize_dtype
from ...tensor import Quantizer
@@ -141,10 +135,10 @@ class Bias(BasicOperation):
dy = grad_output
if dy.dim() > 1:
quantizer = ctx.grad_input_quantizer
if quantizer is not None:
db, dy = tex.bgrad_quantize(dy, quantizer)
else:
if quantizer is None:
db = dy.sum(tuple(range(dy.dim() - 1)))
else:
db, dy = tex.bgrad_quantize(dy, quantizer)
else:
db = dy
return dy, (db,)
@@ -104,7 +104,7 @@ def fuse_backward_activation_bias(
"""
# Check if recipe supports bias activation fusion
if recipe is None or not (recipe.delayed() or recipe.mxfp8()):
if recipe is None:
return ops
# Scan through ops, fusing if possible