Unverified commit 09ffb5d9 authored by Zhongbo Zhu, committed by GitHub

[PyTorch] Support Bgrad Cast FP8 Fusion for FP8 Current Scaling Recipe (#1558)



* add tex.bgrad_quantize support for CS
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Remove unused import
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: zhongboz <zhongboz@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: zhongboz <zhongboz@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 8a20d666
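For context, `tex.bgrad_quantize` fuses two operations on the incoming gradient: the bias-gradient reduction (a column-wise sum) and the FP8 cast. Before this change, the Float8 current scaling recipe fell back to an unfused path. A minimal sketch of the fused kernel's semantics, written as unfused PyTorch (the helper name is illustrative; only `tex.bgrad_quantize` is the real API):

```python
import torch

def bgrad_quantize_reference(grad_output: torch.Tensor, quantizer):
    """Unfused sketch of what the fused bgrad + FP8 cast computes."""
    # bias gradient: sum over all rows, matching the removed fallback below
    grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
    # FP8 cast; calling a TE quantizer on a tensor quantizes it,
    # as in the removed `ctx.grad_fc1_output_quantizer(dact)` fallback
    grad_output_fp8 = quantizer(grad_output)
    return grad_bias, grad_output_fp8
```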
@@ -7,6 +7,7 @@
 #include "common.h"
 #include "pybind.h"
 #include "transformer_engine/cast.h"
+#include "transformer_engine/transformer_engine.h"
 namespace transformer_engine::pytorch {
@@ -42,6 +43,28 @@ std::vector<py::object> bgrad_quantize(const at::Tensor& input, py::handle py_qu
   workspace = makeTransformerEngineTensor(workspace_data_ptr, workspace.shape(), workspace.dtype());
   // Launch kernel
+  if (detail::IsFloat8CurrentScalingQuantizers(py_quantizer.ptr())) {
+    // my_quantizer here has to be a Float8CurrentScalingQuantizer
+    auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer*>(quantizer.get());
+    nvte_compute_amax(input_tensor.data(), out_tensor.data(), at::cuda::getCurrentCUDAStream());
+    // check if we need to do amax reduction (depending on model parallel configs)
+    if (my_quantizer_cs->with_amax_reduction) {
+      c10::intrusive_ptr<dist_group_type> process_group_ptr = my_quantizer_cs->amax_reduction_group;
+      // construct torch tensor from NVTEBasicTensor without reallocating memory
+      at::Tensor& amax_tensor_torch = my_quantizer_cs->amax;
+      std::vector<at::Tensor> tensors = {amax_tensor_torch};
+      // allreduce amax tensor
+      c10d::AllreduceOptions allreduce_opts;
+      allreduce_opts.reduceOp = c10d::ReduceOp::MAX;
+      process_group_ptr->allreduce(tensors, allreduce_opts)->wait();
+    }
+    QuantizationConfigWrapper quant_config;
+    quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
+    quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
+    nvte_compute_scale_from_amax(out_tensor.data(), quant_config, at::cuda::getCurrentCUDAStream());
+    // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
+    out_tensor.set_amax(nullptr, DType::kFloat32, out_tensor.defaultShape);
+  }
   nvte_quantize_dbias(input_tensor.data(), out_tensor.data(), dbias_tensor.data(), workspace.data(),
                       at::cuda::getCurrentCUDAStream());
...
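The new branch above follows the standard current-scaling sequence: compute the per-tensor amax, optionally max-allreduce it across the model-parallel group, derive the scale from the amax, and only then launch the fused dbias + quantize kernel with its atomic amax update disabled (the amax is already final). A hedged Python sketch of the scale computation; the epsilon and pow-2 handling here are assumptions inferred from the config flags set above:

```python
import torch
import torch.distributed as dist

E4M3_MAX = 448.0  # largest representable float8 e4m3 value

def compute_current_scale(x, amax_epsilon=0.0, force_pow_2_scales=False, group=None):
    # per-tensor amax, as produced by nvte_compute_amax
    amax = x.abs().amax().float()
    # optional MAX allreduce over the amax reduction group
    # (mirrors the c10d::ReduceOp::MAX allreduce in the C++ branch)
    if group is not None:
        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
    # assumed epsilon semantics: floor the amax so the scale stays finite
    amax = torch.clamp(amax, min=amax_epsilon)
    scale = E4M3_MAX / amax
    if force_pow_2_scales:
        # assumed pow-2 behavior: round the scale down to a power of two
        scale = torch.exp2(torch.floor(torch.log2(scale)))
    return scale
```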
@@ -35,7 +35,6 @@ from ..constants import dist_group_type
 from ..tensor import QuantizedTensor, Quantizer
 from ..tensor._internal.float8_tensor_base import Float8TensorBase
 from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
-from ..tensor.float8_tensor import Float8CurrentScalingQuantizer
 __all__ = ["initialize_ub", "destroy_ub"]
@@ -860,9 +859,6 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         if ctx.use_bias:
             if isinstance(grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase)):
                 grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0)
-            elif isinstance(quantizer, Float8CurrentScalingQuantizer):
-                # FP8 current scaling does not support fused cast + dbias
-                grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
             else:
                 grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer)
         if not isinstance(grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase)):
...
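Since the removed fallback computed the same bias gradient unfused, it doubles as a reference for a quick equivalence check of the now-enabled fused path. A hypothetical test, assuming `quantizer` is a `Float8CurrentScalingQuantizer` and `grad_output` a 2D gradient tensor:

```python
# reference: the unfused bias gradient from the removed branch
ref_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
# fused path: bias gradient and FP8-cast gradient in one kernel
fused_bias, grad_output_fp8 = tex.bgrad_quantize(grad_output, quantizer)
torch.testing.assert_close(fused_bias, ref_bias)
```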
@@ -797,14 +797,7 @@
                 )  # activation in high precision
                 if ctx.fp8:
-                    # TODO zhongboz: per-tensor current scaling has no bgrad fusion for now
-                    if isinstance(ctx.grad_fc1_output_quantizer, Float8CurrentScalingQuantizer):
-                        fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0)
-                        dact = ctx.grad_fc1_output_quantizer(dact)
-                    else:
-                        fc1_bias_grad, dact = tex.bgrad_quantize(
-                            dact, ctx.grad_fc1_output_quantizer
-                        )
+                    fc1_bias_grad, dact = tex.bgrad_quantize(dact, ctx.grad_fc1_output_quantizer)
                 else:
                     fuse_gemm_and_bias_fc1_wgrad = (
                         True  # fc1_bias_grad is computed later, fused with wgrad gemm for the FC1
...