Unverified commit a9500617, authored by Tim Moon and committed by GitHub

Implement fused kernel for FP8 scale update (#593)



* Implement fused kernel for FP8 scale update
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add fused kernel for amax and scale update

Add unit test.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Replace paddle.fluid imports with paddle.base
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Move fused kernel to core library
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug test
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use FP8 update kernel in Paddle
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug FP8 scale update in Paddle
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix lint errors
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug Paddle test failures
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Make update kernel in-place for PyTorch
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Revert cudnn-frontend commit
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 379c1ee3
......@@ -1046,6 +1046,7 @@ def test_amax_and_scale_update(update_weight_scale_inv):
num_gemm = 6
history_len = 1024
recipe = DelayedScaling()
fp8_dtype = tex.DType.kFloat8E4M3
fp8_max = recipe.fp8_format.value.max_fwd
non_weight_mask = paddle.to_tensor([True, False] * (num_gemm // 2))
......@@ -1073,12 +1074,13 @@ def test_amax_and_scale_update(update_weight_scale_inv):
scale_actual = paddle.zeros_like(scale_tensor)
scale_inv_actual = paddle.zeros_like(scale_tensor)
if update_weight_scale_inv:
non_weight_mask = paddle.empty([0])
tex.amax_and_scale_update_inplace(_amax_history=amax_history_tensor,
_scale=scale_actual,
_scale_inv=scale_inv_actual,
non_weight_mask=non_weight_mask,
update_weight_scale_inv=update_weight_scale_inv,
fp8_max=fp8_max,
fp8_dtype=int(fp8_dtype),
margin=0.,
amax_compute="max")
......
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from typing import Optional
import pytest
import torch
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import (
FP8GlobalStateManager,
amax_and_scale_update,
get_default_fp8_recipe,
)
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
class TestFP8Recipe:
@staticmethod
def setup_class(cls) -> None:
# Configure RNG
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@pytest.mark.parametrize("amax_history_len", [1, 31, 1024])
@pytest.mark.parametrize("amax_compute_algo", ["max", "most_recent"])
@pytest.mark.parametrize("is_first_microbatch", [None, True, False])
def test_amax_and_scale_update(
self,
amax_history_len: int,
amax_compute_algo: str,
is_first_microbatch: Optional[bool],
margin: int = 2,
):
# Construct linear module
fp8_format = transformer_engine.common.recipe.Format.HYBRID
recipe = transformer_engine.common.recipe.DelayedScaling(
margin=margin,
interval=1,
fp8_format=fp8_format,
amax_history_len=amax_history_len,
amax_compute_algo=amax_compute_algo,
)
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
module = te.Linear(16, 16)
y = module(torch.zeros([16, 16], device="cuda"))
y.backward(torch.zeros_like(y))
# Get amax history and scaling factors
fp8_meta = module.fp8_meta
forward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True)
amax_history_forward = fp8_meta[forward_key].amax_history
scale_forward = fp8_meta[forward_key].scale
scale_inv_forward = fp8_meta[forward_key].scale_inv
backward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=False)
amax_history_backward = fp8_meta[backward_key].amax_history
scale_backward = fp8_meta[backward_key].scale
scale_inv_backward = fp8_meta[backward_key].scale_inv
# Tweak amax history and scaling factors
amax_history_forward.copy_(2 * torch.rand_like(amax_history_forward) + 0.5)
if amax_history_len > 1:
amax_history_forward[1, 0].fill_(3)
scale_forward.copy_(2 * torch.rand_like(scale_forward) + 0.5)
scale_inv_forward.copy_(torch.reciprocal(scale_forward))
amax_history_backward.copy_(2 * torch.rand_like(amax_history_backward) + 0.5)
scale_backward.copy_(2 * torch.rand_like(scale_backward) + 0.5)
scale_inv_backward.copy_(torch.reciprocal(scale_backward))
# Expected amax history after update
ref_amax_history_forward = torch.roll(amax_history_forward, -1, dims=0)
ref_amax_history_forward[0].zero_()
ref_amax_history_backward = torch.roll(amax_history_backward, -1, dims=0)
ref_amax_history_backward[0].zero_()
# Expected scale and scale inverse
if amax_compute_algo == "max":
ref_amax_forward = amax_history_forward.max(dim=0).values
ref_amax_backward = amax_history_backward.max(dim=0).values
elif amax_compute_algo == "most_recent":
ref_amax_forward = amax_history_forward[0]
ref_amax_backward = amax_history_backward[0]
else:
raise ValueError(f"{amax_compute_algo=} is not supported")
ref_scale_forward = (fp8_format.value.max_fwd / ref_amax_forward) / (2 ** margin)
ref_scale_backward = (fp8_format.value.max_bwd / ref_amax_backward) / (2 ** margin)
ref_scale_inv_forward = torch.reciprocal(ref_scale_forward)
update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch
if not update_weight_scale_inv:
ref_scale_inv_forward[1].copy_(scale_inv_forward[1])
ref_scale_inv_backward = torch.reciprocal(ref_scale_backward)
# Make sure we are not trivially passing tests
if amax_history_len > 1:
with pytest.raises(AssertionError):
torch.testing.assert_close(
amax_history_forward[1:],
ref_amax_history_forward[1:],
)
with pytest.raises(AssertionError):
torch.testing.assert_close(
scale_forward,
ref_scale_forward,
)
with pytest.raises(AssertionError):
torch.testing.assert_close(
scale_inv_forward,
ref_scale_inv_forward,
)
if amax_history_len > 1:
with pytest.raises(AssertionError):
torch.testing.assert_close(
fp8_meta[backward_key].amax_history[1:],
ref_amax_history_backward[1:],
)
with pytest.raises(AssertionError):
torch.testing.assert_close(
fp8_meta[backward_key].scale,
ref_scale_backward,
)
with pytest.raises(AssertionError):
torch.testing.assert_close(
fp8_meta[backward_key].scale_inv,
ref_scale_inv_backward,
)
# Perform forward and backward pass to update fp8_meta
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
x = torch.zeros([16, 16], device="cuda")
y = module(x, is_first_microbatch=is_first_microbatch)
y.backward(torch.zeros_like(y))
# Check that fp8_meta matches expected values
torch.testing.assert_close(
fp8_meta[forward_key].amax_history[1:],
ref_amax_history_forward[1:],
)
torch.testing.assert_close(
fp8_meta[forward_key].scale,
ref_scale_forward,
)
torch.testing.assert_close(
fp8_meta[forward_key].scale_inv,
ref_scale_inv_forward,
)
torch.testing.assert_close(
fp8_meta[backward_key].amax_history[1:],
ref_amax_history_backward[1:],
)
torch.testing.assert_close(
fp8_meta[backward_key].scale,
ref_scale_backward,
)
torch.testing.assert_close(
fp8_meta[backward_key].scale_inv,
ref_scale_inv_backward,
)
......@@ -34,7 +34,8 @@ list(APPEND transformer_engine_SOURCES
fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu
fused_softmax/scaled_aligned_causal_masked_softmax.cu
fused_rope/fused_rope.cu)
fused_rope/fused_rope.cu
recipe/delayed_scaling.cu)
add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
target_include_directories(transformer_engine PUBLIC
"${CMAKE_CURRENT_SOURCE_DIR}/include")
......
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file recipe.h
* \brief Functions handling FP8 recipes.
*/
#ifndef TRANSFORMER_ENGINE_RECIPE_H_
#define TRANSFORMER_ENGINE_RECIPE_H_
#include "transformer_engine.h"
#ifdef __cplusplus
extern "C" {
#endif
/*! \brief Update FP8 scaling factors with delayed scaling recipe.
*
* The amax history is rotated by -1 (e.g. the first entry shifts to
* the last, the last entry shifts to the second to last) and the
* first entry is set to zero. The scaling factor is estimated so the
* FP8 tensor's maximum absolute value is
* @f$ 2^{-\text{margin}} \text{max}_\text{fp8\_dtype} @f$.
*
* \param[in] amax_history History of maximum absolute values.
* Shape: [history_length, num_scales]
* \param[in] scale Scaling factor for casting to FP8. Shape: [num_scales]
* \param[in] scale_inv Scaling factor for casting from FP8. Shape: [num_scales]
* \param[in] scale_inv_mask Boolean mask indicating scale_inv entries to update. May be
* empty, in which case all scale_inv entries are updated.
* Shape: [num_scales]
* \param[out] updated_amax_history Updated history of maximum absolute values.
* Shape: [history_length, num_scales]
* \param[out] updated_scale Updated scaling factor for casting to FP8.
* Shape: [num_scales]
* \param[out] updated_scale_inv Updated scaling factor for casting from FP8.
* Shape: [num_scales]
* \param[in] amax_compute_algo Method to reduce amax history. Options are "max" and
* "most_recent".
* \param[in] fp8_dtype FP8 datatype.
* \param[in] margin Scaling factor margin.
* \param[in] stream CUDA stream.
*/
void nvte_delayed_scaling_recipe_amax_and_scale_update(const NVTETensor amax_history,
const NVTETensor scale,
const NVTETensor scale_inv,
const NVTETensor scale_inv_mask,
NVTETensor updated_amax_history,
NVTETensor updated_scale,
NVTETensor updated_scale_inv,
const char* amax_compute_algo,
NVTEDType fp8_dtype,
float margin,
cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
#endif // TRANSFORMER_ENGINE_RECIPE_H_
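
For reference, the update rule documented in the header above can be written in a few lines of plain PyTorch. This is an illustrative sketch only (the function name and default values are ours), not the fused CUDA path added by this commit: the amax history is rotated by -1 with the freed slot zeroed, the amax is reduced with either "max" or "most_recent", the new scale targets a maximum FP8 magnitude of 2^-margin * fp8_max, and scale_inv is refreshed only where the mask allows it (an empty mask means update every entry).

```python
# Pure-PyTorch reference of the delayed-scaling update (sketch, not the fused kernel).
# Shapes follow the header doc: amax_history is [history_length, num_scales];
# scale, scale_inv, and scale_inv_mask are [num_scales].
import torch

def delayed_scaling_update_reference(
    amax_history: torch.Tensor,
    scale: torch.Tensor,
    scale_inv: torch.Tensor,
    scale_inv_mask: torch.Tensor,  # boolean mask; empty tensor means "update all"
    amax_compute_algo: str = "max",
    fp8_max: float = 448.0,  # max representable value of E4M3
    margin: float = 0.0,
):
    # Reduce the amax history to a single amax per scale
    if amax_compute_algo == "max":
        amax = amax_history.max(dim=0).values
    elif amax_compute_algo == "most_recent":
        amax = amax_history[0]
    else:
        raise ValueError(f"Unsupported amax_compute_algo: {amax_compute_algo}")

    # Rotate the history by -1 and zero the slot for the next iteration's amax
    new_history = torch.roll(amax_history, -1, dims=0)
    new_history[0].zero_()

    # Scale is chosen so the FP8 tensor's max abs value is 2^-margin * fp8_max;
    # keep the old scale where amax is zero or non-finite
    proposed = (fp8_max / amax) / (2.0 ** margin)
    valid = torch.isfinite(amax) & (amax > 0)
    new_scale = torch.where(valid, proposed, scale)

    # scale_inv is refreshed only where the mask allows it (empty mask = everywhere)
    if scale_inv_mask.numel() == 0:
        new_scale_inv = 1.0 / new_scale
    else:
        new_scale_inv = torch.where(scale_inv_mask, 1.0 / new_scale, scale_inv)
    return new_history, new_scale, new_scale_inv
```
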
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/recipe.h>
#include <cmath>
#include <string>
#include "../common.h"
#include "../util/logging.h"
namespace transformer_engine {
namespace delayed_scaling_recipe {
namespace {
// amax value to use for updating scaling factor
enum class AmaxComputeAlgo { INVALID, MOST_RECENT, MAX };
const char* dtype_name(DType dtype) {
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, Type,
return TypeInfo<Type>::name;
); // NOLINT(*)
return "";
}
// Maximum representable value of an FP8 dtype
inline float fp8_dtype_max(DType dtype) {
switch (dtype) {
case DType::kFloat8E4M3: return 448;
case DType::kFloat8E5M2: return 57344;
default:
NVTE_ERROR("Expected FP8 dtype, but got ", dtype_name(dtype));
}
return 0;
}
namespace amax_and_scale_update_impl {
// CUDA block size
constexpr size_t bsize = 256;
/* CUDA kernel to update amax history and FP8 scaling factors
*
* Block dims: bsize x 1 x 1
*
* Grid dims: num_scales x 1 x 1
*/
__global__ void __launch_bounds__(bsize)
kernel(const float* amax_history_ptr,
const float* scale_ptr,
const float* scale_inv_ptr,
const unsigned char* scale_inv_mask_ptr,
float* updated_amax_history_ptr,
float* updated_scale_ptr,
float* updated_scale_inv_ptr,
size_t amax_history_length,
size_t amax_history_stride,
AmaxComputeAlgo amax_compute_algo,
float scaled_max) {
const size_t tid = threadIdx.x;
const size_t bid = blockIdx.x;
// Update amax
float amax = 0;
{
// Roll amax history
const auto* amax_history = amax_history_ptr + bid;
auto* updated_amax_history = updated_amax_history_ptr + bid;
const auto last_amax = amax_history[0];
const auto& length = amax_history_length;
const auto& stride = amax_history_stride;
for (size_t off = 0; off < length; off += bsize) {
const size_t i = off + tid;
float a = 0;
if (i < length) {
a = (i < length - 1) ? amax_history[(i+1)*stride] : last_amax;
amax = fmaxf(amax, a);
}
__syncthreads(); // In case roll is in-place
if (i < length) {
updated_amax_history[i*stride] = (i > 0) ? a : 0;
}
}
// Compute amax to use for scaling factor
switch (amax_compute_algo) {
case AmaxComputeAlgo::MOST_RECENT:
amax = last_amax;
break;
case AmaxComputeAlgo::MAX:
{
__shared__ float shared_amax[bsize];
shared_amax[tid] = amax;
__syncthreads();
#pragma unroll
for (size_t off = bsize / 2; off > 0; off /= 2) {
if (tid < off) {
shared_amax[tid] = fmaxf(shared_amax[tid], shared_amax[tid + off]);
}
__syncthreads();
}
amax = shared_amax[tid];
}
break;
default:
amax = 0;
}
}
// Update scale and scale inverse
if (tid == 0) {
// Update scale
float scale;
if (isfinite(amax) && amax > 0) {
scale = scaled_max / amax;
} else {
scale = scale_ptr[bid];
}
updated_scale_ptr[bid] = scale;
// Update scale inverse
float scale_inv;
if (scale_inv_mask_ptr == nullptr || scale_inv_mask_ptr[bid]) {
scale_inv = 1 / scale;
} else {
scale_inv = scale_inv_ptr[bid];
}
updated_scale_inv_ptr[bid] = scale_inv;
}
}
} // namespace amax_and_scale_update_impl
} // namespace
void amax_and_scale_update(const Tensor &amax_history,
const Tensor &scale,
const Tensor &scale_inv,
const Tensor &scale_inv_mask,
Tensor *updated_amax_history_,
Tensor *updated_scale_,
Tensor *updated_scale_inv_,
const std::string &amax_compute_algo,
DType fp8_dtype,
float margin,
cudaStream_t stream) {
auto& updated_amax_history = *updated_amax_history_;
auto& updated_scale = *updated_scale_;
auto& updated_scale_inv = *updated_scale_inv_;
// Number of elements in tensor
auto numel = [] (const Tensor &tensor) -> size_t {
size_t acc = 1;
for (const auto& dim : tensor.data.shape) {
acc *= dim;
}
return acc;
};
// Check tensors
NVTE_CHECK(amax_history.data.shape.size() == 2,
"Found ", amax_history.data.shape.size(), " dims");
const size_t amax_history_length = amax_history.data.shape[0];
const size_t num_scales = amax_history.data.shape[1];
NVTE_CHECK(amax_history.data.dtype == DType::kFloat32,
"Found ", dtype_name(amax_history.data.dtype), ".");
NVTE_CHECK(numel(scale) == num_scales,
"Expected ", num_scales, " elements, ",
"but found ", numel(scale), ".");
NVTE_CHECK(scale.data.dtype == DType::kFloat32,
"Found ", dtype_name(scale.data.dtype), ".");
if (scale_inv_mask.data.dptr != nullptr) {
NVTE_CHECK(numel(scale_inv) == num_scales,
"Expected ", num_scales, " elements, ",
"but found ", numel(scale_inv), ".");
NVTE_CHECK(scale_inv.data.dtype == DType::kFloat32);
NVTE_CHECK(numel(scale_inv_mask) == num_scales,
"Expected ", num_scales, " elements, ",
"but found ", numel(scale_inv_mask), ".");
NVTE_CHECK(scale_inv_mask.data.dtype == DType::kByte,
"Found ", dtype_name(scale_inv_mask.data.dtype), ".");
}
NVTE_CHECK(updated_amax_history.data.shape.size() == 2,
"Found ", updated_amax_history.data.shape.size(), " dims.");
NVTE_CHECK(updated_amax_history.data.shape[0] == amax_history_length,
"Expected ", amax_history_length, ", ",
"but found ", updated_amax_history.data.shape[0]);
NVTE_CHECK(updated_amax_history.data.shape[1] == num_scales,
"Expected ", num_scales, ", ",
"but found ", updated_amax_history.data.shape[1]);
NVTE_CHECK(updated_amax_history.data.dtype == DType::kFloat32,
"Got ", dtype_name(updated_amax_history.data.dtype), ".");
NVTE_CHECK(numel(updated_scale) == num_scales,
"Expected ", num_scales, " elements, ",
"but found ", numel(updated_scale), ".");
NVTE_CHECK(updated_scale.data.dtype == DType::kFloat32,
"Got ", dtype_name(updated_scale.data.dtype), ".");
NVTE_CHECK(numel(updated_scale_inv) == num_scales,
"Expected ", num_scales, " elements, ",
"but found ", numel(updated_scale_inv), ".");
NVTE_CHECK(updated_scale_inv.data.dtype == DType::kFloat32,
"Got ", dtype_name(updated_scale_inv.data.dtype), ".");
// amax value to use for updating scaling factor
AmaxComputeAlgo amax_compute_algo_ = AmaxComputeAlgo::INVALID;
if (amax_compute_algo == "max") {
amax_compute_algo_ = AmaxComputeAlgo::MAX;
} else if (amax_compute_algo == "most_recent") {
amax_compute_algo_ = AmaxComputeAlgo::MOST_RECENT;
} else {
NVTE_ERROR("Unsupported amax compute algorithm (", amax_compute_algo, ")");
}
// Expected maximum value after scale is applied
const float scaled_max = fp8_dtype_max(fp8_dtype) * std::pow(2.f, -margin);
// Launch CUDA kernel
constexpr size_t block_size = amax_and_scale_update_impl::bsize;
const size_t grid_size = num_scales;
amax_and_scale_update_impl::kernel
<<<grid_size, block_size, 0, stream>>>(
static_cast<const float*>(amax_history.data.dptr),
static_cast<const float*>(scale.data.dptr),
static_cast<const float*>(scale_inv.data.dptr),
static_cast<const unsigned char*>(scale_inv_mask.data.dptr),
static_cast<float*>(updated_amax_history.data.dptr),
static_cast<float*>(updated_scale.data.dptr),
static_cast<float*>(updated_scale_inv.data.dptr),
amax_history_length,
num_scales,
amax_compute_algo_,
scaled_max);
NVTE_CHECK_CUDA(cudaGetLastError());
}
} // namespace delayed_scaling_recipe
} // namespace transformer_engine
void nvte_delayed_scaling_recipe_amax_and_scale_update(const NVTETensor amax_history,
const NVTETensor scale,
const NVTETensor scale_inv,
const NVTETensor scale_inv_mask,
NVTETensor updated_amax_history,
NVTETensor updated_scale,
NVTETensor updated_scale_inv,
const char *amax_compute_algo,
NVTEDType fp8_dtype,
float margin,
cudaStream_t stream) {
NVTE_API_CALL(nvte_delayed_scaling_recipe_amax_and_scale_update);
using namespace transformer_engine;
delayed_scaling_recipe::amax_and_scale_update(
*reinterpret_cast<const Tensor*>(amax_history),
*reinterpret_cast<const Tensor*>(scale),
*reinterpret_cast<const Tensor*>(scale_inv),
*reinterpret_cast<const Tensor*>(scale_inv_mask),
reinterpret_cast<Tensor*>(updated_amax_history),
reinterpret_cast<Tensor*>(updated_scale),
reinterpret_cast<Tensor*>(updated_scale_inv),
amax_compute_algo,
static_cast<DType>(fp8_dtype),
margin,
stream);
}
......@@ -17,6 +17,7 @@
#include <transformer_engine/fused_attn.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/layer_norm.h>
#include <transformer_engine/recipe.h>
#include <transformer_engine/rmsnorm.h>
#include <transformer_engine/softmax.h>
#include <transformer_engine/transformer_engine.h>
......
......@@ -1189,54 +1189,31 @@ void te_scaled_upper_triang_masked_softmax_backward(paddle::Tensor &output_grads
softmax_results.stream());
}
__global__ void UpdateFP8MetaKernel(const float *amax, const float *rolled_amax_history,
const bool *non_weight_mask, float *amax_history, float *scale,
float *scale_inv, bool update_weight_scale_inv, float margin,
float fp8_max, size_t history_numel, size_t amax_numel) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= history_numel) {
return;
}
amax_history[idx] = rolled_amax_history[idx];
if (idx < amax_numel) {
float sf = (fp8_max / amax[idx]) / powf(2.0f, margin);
float scale_reg = ((amax[idx] > 0.0f) && isfinite(amax[idx])) ? sf : scale[idx];
scale[idx] = scale_reg;
if (update_weight_scale_inv || non_weight_mask[idx]) scale_inv[idx] = 1.0f / scale_reg;
amax_history[idx] = 0.0f;
}
}
constexpr int BLOCK_SIZE = 512;
void amax_and_scale_update_inplace(paddle::Tensor &amax_history, // NOLINT
paddle::Tensor &scale, // NOLINT
paddle::Tensor &scale_inv, // NOLINT
const paddle::Tensor &non_weight_mask,
bool update_weight_scale_inv, float fp8_max, float margin,
int64_t fp8_dtype,
float margin,
const std::string &amax_compute) {
NVTE_CHECK(amax_compute == "max" || amax_compute == "most_recent");
paddle::Tensor amax;
if (amax_compute == "max") {
amax = amax_history.max({0});
} else {
amax = amax_history.slice(0, 1);
}
const auto rolled_amax_history = amax_history.roll({-1}, {0});
auto size = amax_history.numel();
size_t num_blocks = (size + BLOCK_SIZE - 1) / BLOCK_SIZE;
UpdateFP8MetaKernel<<<num_blocks, BLOCK_SIZE, 0, amax_history.stream()>>>(
amax.data<float>(), rolled_amax_history.data<float>(), non_weight_mask.data<bool>(),
amax_history.data<float>(), scale.data<float>(), scale_inv.data<float>(),
update_weight_scale_inv, margin, fp8_max, amax_history.numel(), amax.numel());
NVTE_CHECK_CUDA(cudaGetLastError());
auto amax_history_ = MakeNvteTensor(amax_history);
auto scale_ = MakeNvteTensor(scale);
auto scale_inv_ = MakeNvteTensor(scale_inv);
const auto non_weight_mask_ = MakeNvteTensor(non_weight_mask);
nvte_delayed_scaling_recipe_amax_and_scale_update(
amax_history_.data(),
scale_.data(),
scale_inv_.data(),
non_weight_mask_.data(),
amax_history_.data(),
scale_.data(),
scale_inv_.data(),
amax_compute.c_str(),
static_cast<NVTEDType>(fp8_dtype),
margin,
amax_history.stream());
}
void update_latest_amax_history_inplace(paddle::Tensor &history, // NOLINT
......@@ -1567,8 +1544,7 @@ PD_BUILD_OP(amax_and_scale_update_inplace)
.SetInplaceMap({{"_amax_history", "amax_history"},
{"_scale", "scale"},
{"_scale_inv", "scale_inv"}})
.Attrs({"update_weight_scale_inv: bool", "fp8_max: float", "margin: float",
"amax_compute: std::string"})
.Attrs({"fp8_dtype: int64_t", "margin: float", "amax_compute: std::string"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::amax_and_scale_update_inplace));
PD_BUILD_OP(update_latest_amax_history_inplace)
......
......@@ -231,16 +231,17 @@ def amax_and_scale_update(
amax_compute = fp8_meta["recipe"].amax_compute_algo
sf_compute = fp8_meta["recipe"].scaling_factor_compute_algo
fp8_meta_tensor_key = "scaling_fwd" if fwd_update else "scaling_bwd"
fp8_max_key = "fp8_max_fwd" if fwd_update else "fp8_max_bwd"
if not callable(amax_compute) and sf_compute is None:
non_weight_mask = fp8_meta[fp8_meta_tensor_key].non_weight_mask
if update_weight_scale_inv:
non_weight_mask = paddle.empty([0])
tex.amax_and_scale_update_inplace(
_amax_history=fp8_meta[fp8_meta_tensor_key].amax_history,
_scale=fp8_meta[fp8_meta_tensor_key].scale,
_scale_inv=fp8_meta[fp8_meta_tensor_key].scale_inv,
non_weight_mask=fp8_meta[fp8_meta_tensor_key].non_weight_mask,
update_weight_scale_inv=update_weight_scale_inv,
fp8_max=fp8_meta[fp8_max_key],
non_weight_mask=non_weight_mask,
fp8_dtype=int(get_fp8_te_dtype(fp8_meta["recipe"], fwd_update)),
margin=float(fp8_meta["recipe"].margin),
amax_compute=amax_compute)
else:
......
......@@ -38,6 +38,7 @@
#include <transformer_engine/fused_rope.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/layer_norm.h>
#include <transformer_engine/recipe.h>
#include <transformer_engine/rmsnorm.h>
#include <transformer_engine/softmax.h>
#include <transformer_engine/transformer_engine.h>
......
......@@ -7,6 +7,10 @@
#include "common.h"
#include "common/common.h"
/***************************************************************************************************
* Attention
**************************************************************************************************/
NVTE_Fused_Attn_Backend get_fused_attn_backend(
const transformer_engine::DType q_dtype,
const transformer_engine::DType kv_dtype,
......@@ -149,9 +153,12 @@ std::vector<at::Tensor> fused_attn_bwd(
c10::optional<at::Tensor> amax_dQKV);
at::Tensor fa_prepare_fwd(at::Tensor qkvi);
at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v);
/***************************************************************************************************
* GEMM
**************************************************************************************************/
void te_gemm(at::Tensor A,
at::Tensor A_scale_inverse,
transformer_engine::DType A_type,
......@@ -202,6 +209,10 @@ void te_atomic_gemm(at::Tensor A,
at::Tensor counter
);
/***************************************************************************************************
* Transpose
**************************************************************************************************/
void fused_cast_transpose(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
......@@ -532,6 +543,21 @@ at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grads
float scale_factor
);
/***************************************************************************************************
* FP8 recipe
**************************************************************************************************/
void fused_amax_and_scale_update(const at::Tensor &amax_history,
const at::Tensor &scale,
const at::Tensor &scale_inv,
const at::Tensor &scale_inv_mask,
at::Tensor updated_amax_history,
at::Tensor updated_scale,
at::Tensor updated_scale_inv,
const std::string& amax_compute_algo,
transformer_engine::DType fp8_dtype,
float margin);
/***************************************************************************************************
* Rotary positional embedding
**************************************************************************************************/
......@@ -557,7 +583,7 @@ at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads,
);
/***************************************************************************************************
* Misc
* Miscellaneous
**************************************************************************************************/
size_t get_cublasLt_version();
......
......@@ -80,6 +80,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention");
m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention");
m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend");
m.def("fused_amax_and_scale_update",
&fused_amax_and_scale_update,
"Update amax history and FP8 scale");
// fused apply rope
m.def("fused_rope_forward", &fused_rope_forward, "Fused Apply RoPE FWD");
......
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
#include <string>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
void fused_amax_and_scale_update(const at::Tensor &amax_history,
const at::Tensor &scale,
const at::Tensor &scale_inv,
const at::Tensor &scale_inv_mask,
at::Tensor updated_amax_history,
at::Tensor updated_scale,
at::Tensor updated_scale_inv,
const std::string& amax_compute_algo,
transformer_engine::DType fp8_dtype,
float margin) {
nvte_delayed_scaling_recipe_amax_and_scale_update(
makeTransformerEngineTensor(amax_history).data(),
makeTransformerEngineTensor(scale).data(),
makeTransformerEngineTensor(scale_inv).data(),
makeTransformerEngineTensor(scale_inv_mask).data(),
makeTransformerEngineTensor(updated_amax_history).data(),
makeTransformerEngineTensor(updated_scale).data(),
makeTransformerEngineTensor(updated_scale_inv).data(),
amax_compute_algo.c_str(),
static_cast<NVTEDType>(fp8_dtype),
margin,
at::cuda::getCurrentCUDAStream());
}
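
A minimal usage sketch for this new binding, assuming the extension built from this commit is importable as `transformer_engine_extensions` (tensor sizes here are arbitrary). Passing the same tensors as both the input and `updated_*` arguments makes the update in-place, which is how the PyTorch wrapper in fp8.py (further below) drives it; an empty mask tensor means every `scale_inv` entry is refreshed.

```python
# Illustrative only: assumes transformer_engine with this commit is installed.
import torch
import transformer_engine_extensions as tex

history_len, num_scales = 16, 4
amax_history = torch.rand(history_len, num_scales, device="cuda")
scale = torch.ones(num_scales, device="cuda")
scale_inv = torch.ones(num_scales, device="cuda")

# Same tensors passed as inputs and outputs -> in-place update;
# empty mask -> update all scale_inv entries.
tex.fused_amax_and_scale_update(
    amax_history, scale, scale_inv, torch.Tensor(),
    amax_history, scale, scale_inv,
    "max", tex.DType.kFloat8E4M3, 0.0,
)
```
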
......@@ -625,41 +625,31 @@ def _compute_scaling_factor_inverse(
return torch.where(non_weight_mask, 1.0 / scale, scale_inv)
@torch.jit.script
def _fused_amax_and_scale_update(
amax_history: torch.Tensor,
scale: torch.Tensor,
scale_inv: torch.Tensor,
fp8_max: float,
fp8_dtype: tex.DType,
margin: int,
amax_compute_algo: str,
non_weight_mask: torch.Tensor,
update_weight_scale_inv: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Amax to scale conversion."""
# Get amax from history.
amax_history, amax = _default_get_amax(
"""Update amax history and FP8 scaling factors"""
if update_weight_scale_inv:
non_weight_mask = torch.Tensor()
tex.fused_amax_and_scale_update(
amax_history,
amax_compute_algo,
)
# Calculate new scaling factor.
scale = _default_sf_compute(
amax,
scale,
fp8_max,
margin,
)
# Calculate new inverse of scaling factor.
scale_inv = _compute_scaling_factor_inverse(
scale,
scale_inv,
non_weight_mask,
update_weight_scale_inv,
amax_history,
scale,
scale_inv,
amax_compute_algo,
fp8_dtype,
margin,
)
return amax_history, scale, scale_inv
......@@ -717,7 +707,7 @@ def amax_and_scale_update(
fp8_meta[fp8_meta_tensor_key].amax_history,
fp8_meta[fp8_meta_tensor_key].scale,
fp8_meta[fp8_meta_tensor_key].scale_inv,
fp8_meta[fp8_max_key],
get_fp8_te_dtype(fp8_meta["recipe"], fwd_update),
fp8_meta["recipe"].margin,
fp8_meta["recipe"].amax_compute_algo,
fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"],
......