Unverified commit 8e3561bf, authored by Tim Moon and committed by GitHub
Browse files

Update FP8 scale-inverse in kernels with FP8 output (#1083)



* Perform scale-inv update in cast-transpose kernels
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Perform scale-inv update in cast and activation kernels
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Perform scale-inv update in LayerNorm and RMSNorm kernels
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Perform scale-inv update after FP8 GEMMs
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fuse casts and scale-inv updates in linear module
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fuse casts and scale-inv updates in layernorm-linear module
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Simplify kernel to update FP8 scale-inv
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix typos
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug amax update in layernorm kernels
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Debug test failures
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug ONNX export

Use quantization scaling factor in ONNX quantize op.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Review suggestion from @ptrendx
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Debug mismatched dtypes
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 5d5fe819
......@@ -229,12 +229,16 @@ __global__ void __launch_bounds__(BLOCK_SIZE)
}
}
// warp tile amax reduce
const CType max_block = reduce_max<BLOCK_SIZE / THREADS_PER_WARP>(amax, warp_id);
if (threadIdx.x == 0) {
if (param.amax != nullptr) {
// Reduce amax over block
if (param.amax != nullptr) {
const CType max_block = reduce_max<BLOCK_SIZE / THREADS_PER_WARP>(amax, warp_id);
if (threadIdx.x == 0) {
atomicMaxFloat(param.amax, max_block);
}
}
// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && param.scale_inv != nullptr) {
reciprocal<CType>(param.scale_inv, scale);
}
}
......@@ -46,7 +46,8 @@ void fp8_quantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<const fp32 *>(output->scale.dptr),
reinterpret_cast<fp32 *>(output->amax.dptr), N, {},
reinterpret_cast<fp32 *>(output->amax.dptr),
reinterpret_cast<fp32 *>(output->scale_inv.dptr), N, {},
stream);); // NOLINT(*)
); // NOLINT(*)
}
......@@ -68,7 +69,7 @@ void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
p.scale_inv = reinterpret_cast<const fp32 *>(input.scale_inv.dptr);
VectorizedUnaryKernelLauncher<nvec, detail::DequantizeParam, detail::dequantize_func>(
reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<OType *>(output->data.dptr), nullptr, nullptr, N, p,
reinterpret_cast<OType *>(output->data.dptr), nullptr, nullptr, nullptr, N, p,
stream);); // NOLINT(*)
); // NOLINT(*)
}
......
......@@ -168,12 +168,12 @@ template <int nvec, bool aligned, typename ComputeType, typename Param,
ComputeType (*OP)(ComputeType, const Param &), typename InputType, typename OutputType>
__launch_bounds__(unary_kernel_threads) __global__
void unary_kernel(const InputType *input, OutputType *output, const ComputeType *scale,
ComputeType *amax, Param p, const size_t N,
ComputeType *amax, ComputeType *scale_inv, Param p, const size_t N,
const size_t num_aligned_elements) {
VectorizedLoader<InputType, nvec, aligned> loader(input, N);
VectorizedStorer<OutputType, nvec, aligned> storer(output, N);
ComputeType max = 0;
ComputeType s = 0;
ComputeType s = 1;
if constexpr (is_fp8<OutputType>::value) {
if (scale != nullptr) s = *scale;
}
......@@ -199,12 +199,18 @@ __launch_bounds__(unary_kernel_threads) __global__
storer.store(tid, N);
}
if constexpr (is_fp8<OutputType>::value) {
/* warp tile amax reduce*/
max = reduce_max<unary_kernel_threads / THREADS_PER_WARP>(max, warp_id);
// Reduce amax over block
if (amax != nullptr) {
max = reduce_max<unary_kernel_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0) {
static_assert(std::is_same<ComputeType, float>::value);
atomicMaxFloat(amax, max);
}
}
if (threadIdx.x == 0 && amax != nullptr) {
static_assert(std::is_same<ComputeType, float>::value);
atomicMaxFloat(amax, max);
// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) {
reciprocal<ComputeType>(scale_inv, s);
}
}
}
......@@ -214,13 +220,13 @@ template <int nvec, bool aligned, typename ComputeType, typename Param,
typename OutputType>
__launch_bounds__(unary_kernel_threads) __global__
void unary_grad_kernel(const InputTypeGrad *grad, const InputType *input, OutputType *output,
const ComputeType *scale, ComputeType *amax, Param p, const size_t N,
const size_t num_aligned_elements) {
const ComputeType *scale, ComputeType *amax, ComputeType *scale_inv,
Param p, const size_t N, const size_t num_aligned_elements) {
VectorizedLoader<InputType, nvec, aligned> loader(input, N);
VectorizedLoader<InputTypeGrad, nvec, aligned> grad_loader(grad, N);
VectorizedStorer<OutputType, nvec, aligned> storer(output, N);
ComputeType max = 0;
ComputeType s = 0;
ComputeType s = 1;
if constexpr (is_fp8<OutputType>::value) {
if (scale != nullptr) s = *scale;
}
......@@ -248,12 +254,18 @@ __launch_bounds__(unary_kernel_threads) __global__
storer.store(tid, N);
}
if constexpr (is_fp8<OutputType>::value) {
/* warp tile amax reduce*/
max = reduce_max<unary_kernel_threads / THREADS_PER_WARP>(max, warp_id);
// Reduce amax over block
if (amax != nullptr) {
max = reduce_max<unary_kernel_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0) {
static_assert(std::is_same<ComputeType, float>::value);
atomicMaxFloat(amax, max);
}
}
if (threadIdx.x == 0 && amax != nullptr) {
static_assert(std::is_same<ComputeType, float>::value);
atomicMaxFloat(amax, max);
// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) {
reciprocal<ComputeType>(scale_inv, s);
}
}
}
......@@ -311,7 +323,7 @@ Alignment CheckAlignment(const size_t lead_dim, const int nvec, const T... ptrs)
template <int nvec, typename Param, fp32 (*OP)(const fp32, const Param &), typename InputType,
typename OutputType>
void VectorizedUnaryKernelLauncher(const InputType *input, OutputType *output, const fp32 *scale,
fp32 *amax, const size_t N, const Param params,
fp32 *amax, fp32 *scale_inv, const size_t N, const Param params,
cudaStream_t stream) {
if (N != 0) {
auto align = CheckAlignment(N, nvec, input, output);
......@@ -325,16 +337,16 @@ void VectorizedUnaryKernelLauncher(const InputType *input, OutputType *output, c
switch (align) {
case Alignment::SAME_ALIGNED:
unary_kernel<nvec, true, fp32, Param, OP><<<num_blocks, threads, 0, stream>>>(
input, output, scale, amax, params, N, num_aligned_elements);
input, output, scale, amax, scale_inv, params, N, num_aligned_elements);
break;
case Alignment::SAME_UNALIGNED:
unary_kernel<nvec, false, fp32, Param, OP><<<num_blocks, threads, 0, stream>>>(
input, output, scale, amax, params, N, num_aligned_elements);
input, output, scale, amax, scale_inv, params, N, num_aligned_elements);
break;
case Alignment::DIFFERENT: {
// If the pointers are aligned differently we cannot vectorize
unary_kernel<1, true, fp32, Param, OP>
<<<num_blocks, threads, 0, stream>>>(input, output, scale, amax, params, N, N);
unary_kernel<1, true, fp32, Param, OP><<<num_blocks, threads, 0, stream>>>(
input, output, scale, amax, scale_inv, params, N, N);
break;
}
}
......@@ -345,7 +357,8 @@ template <int nvec, typename Param, fp32 (*OP)(fp32, const Param &), typename In
typename InputTypeGrad, typename OutputType>
void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad, const InputType *input,
OutputType *output, const fp32 *scale, fp32 *amax,
const size_t N, const Param params, cudaStream_t stream) {
fp32 *scale_inv, const size_t N, const Param params,
cudaStream_t stream) {
if (N != 0) {
auto align = CheckAlignment(N, nvec, input, grad, output);
......@@ -358,16 +371,16 @@ void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad, const InputTyp
switch (align) {
case Alignment::SAME_ALIGNED:
unary_grad_kernel<nvec, true, fp32, Param, OP><<<num_blocks, threads, 0, stream>>>(
grad, input, output, scale, amax, params, N, num_aligned_elements);
grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements);
break;
case Alignment::SAME_UNALIGNED:
unary_grad_kernel<nvec, false, fp32, Param, OP><<<num_blocks, threads, 0, stream>>>(
grad, input, output, scale, amax, params, N, num_aligned_elements);
grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements);
break;
case Alignment::DIFFERENT: {
// If the pointers are aligned differently we cannot vectorize
unary_grad_kernel<1, true, fp32, Param, OP>
<<<num_blocks, threads, 0, stream>>>(grad, input, output, scale, amax, params, N, N);
unary_grad_kernel<1, true, fp32, Param, OP><<<num_blocks, threads, 0, stream>>>(
grad, input, output, scale, amax, scale_inv, params, N, N);
break;
}
}
......@@ -379,8 +392,8 @@ template <int nvec, bool aligned, typename ComputeType, typename Param,
typename OutputType>
__launch_bounds__(unary_kernel_threads) __global__
void gated_act_kernel(const InputType *input, OutputType *output, const ComputeType *scale,
ComputeType *amax, const size_t m, const size_t n, const Param p,
const size_t num_aligned_elements) {
ComputeType *amax, ComputeType *scale_inv, const size_t m, const size_t n,
const Param p, const size_t num_aligned_elements) {
const size_t M = num_aligned_elements * m;
for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) {
const size_t id_x = tid % num_aligned_elements;
......@@ -389,7 +402,7 @@ __launch_bounds__(unary_kernel_threads) __global__
VectorizedLoader<InputType, nvec, aligned> loader1(input + id_y * n * 2 + n, n);
VectorizedStorer<OutputType, nvec, aligned> storer(output + id_y * n, n);
ComputeType max = 0;
ComputeType s = 0;
ComputeType s = 1;
if constexpr (is_fp8<OutputType>::value) {
if (scale != nullptr) s = *scale;
}
......@@ -412,12 +425,18 @@ __launch_bounds__(unary_kernel_threads) __global__
storer.store(id_x, n);
if constexpr (is_fp8<OutputType>::value) {
/* warp tile amax reduce*/
max = reduce_max<unary_kernel_threads / THREADS_PER_WARP>(max, warp_id);
// Reduce amax over block
if (amax != nullptr) {
max = reduce_max<unary_kernel_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0) {
static_assert(std::is_same<ComputeType, float>::value);
atomicMaxFloat(amax, max);
}
}
if (threadIdx.x == 0 && amax != nullptr) {
static_assert(std::is_same<ComputeType, float>::value);
atomicMaxFloat(amax, max);
// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) {
reciprocal<ComputeType>(scale_inv, s);
}
}
}
......@@ -427,8 +446,8 @@ template <int nvec, typename ComputeType, typename Param,
ComputeType (*Activation)(const ComputeType, const Param &), typename InputType,
typename OutputType>
void GatedActivationKernelLauncher(const InputType *input, OutputType *output, const fp32 *scale,
fp32 *amax, const size_t m, const size_t n, const Param &p,
cudaStream_t stream) {
fp32 *amax, fp32 *scale_inv, const size_t m, const size_t n,
const Param &p, cudaStream_t stream) {
if (m != 0 && n != 0) {
size_t num_aligned_elements = get_num_aligned_elements(input, n, nvec, sizeof(InputType));
constexpr size_t threads = unary_kernel_threads;
......@@ -439,18 +458,18 @@ void GatedActivationKernelLauncher(const InputType *input, OutputType *output, c
switch (auto align = CheckAlignment(n, nvec, input, input + n, output)) {
case Alignment::SAME_ALIGNED:
gated_act_kernel<nvec, true, ComputeType, Param, Activation>
<<<num_blocks, threads, 0, stream>>>(input, output, scale, amax, m, n, p,
<<<num_blocks, threads, 0, stream>>>(input, output, scale, amax, scale_inv, m, n, p,
num_aligned_elements);
break;
case Alignment::SAME_UNALIGNED:
gated_act_kernel<nvec, false, ComputeType, Param, Activation>
<<<num_blocks, threads, 0, stream>>>(input, output, scale, amax, m, n, p,
<<<num_blocks, threads, 0, stream>>>(input, output, scale, amax, scale_inv, m, n, p,
num_aligned_elements);
break;
case Alignment::DIFFERENT: {
// If the pointers are aligned differently we cannot vectorize
gated_act_kernel<1, true, ComputeType, Param, Activation>
<<<num_blocks, threads, 0, stream>>>(input, output, scale, amax, m, n, p, n);
<<<num_blocks, threads, 0, stream>>>(input, output, scale, amax, scale_inv, m, n, p, n);
break;
}
}
......
......@@ -852,6 +852,11 @@ __device__ __forceinline__ void reciprocal(T *value_inv, const T value) {
*value_inv = 1 / value;
}
// Single-precision specialization of reciprocal(): writes the reciprocal of
// `value` to `*value_inv` through the CUDA `__frcp_rn` intrinsic (the
// round-to-nearest reciprocal in the CUDA math API) instead of the generic
// `1 / value` expression used by the primary template.
// `value_inv` is dereferenced unconditionally, so callers must pass a valid
// pointer.
template <>
__device__ __forceinline__ void reciprocal<float>(float *value_inv, const float value) {
*value_inv = __frcp_rn(value);
}
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_UTILS_CUH_
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Helper functions for C++ extensions"""
import functools
from typing import Dict, Optional, Tuple, Union
import torch
import transformer_engine_torch as tex
@functools.lru_cache(maxsize=None)
def empty_tensor() -> torch.Tensor:
    """Get tensor with no entries and no data

    The result is memoized with `functools.lru_cache`, so every call
    returns the same tensor object rather than constructing a new one.
    Callers should treat it as a read-only placeholder.
    """
    return torch.Tensor()
def canonicalize_fp8_scales(
    *,
    scale: Optional[torch.Tensor] = None,
    amax: Optional[torch.Tensor] = None,
    scale_inv: Optional[torch.Tensor] = None,
    fp8_meta: Optional[tex.FP8TensorMeta] = None,
    fp8_meta_index: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None] = None,
    allow_multiple_offsets: bool = True,
) -> Tuple[Dict[str, torch.Tensor], Dict[str, int]]:
    """Canonicalize FP8 scaling factors (scale, amax, scale-inverse)

    If a scaling factor is not provided, try to access it within the
    FP8 meta tensors. Returns dict with tensors and dict with tensor
    offsets.

    Parameters
    ----------
    scale, amax, scale_inv:
        Explicitly provided scaling-factor tensors. A tensor passed
        here is used as-is with offset 0.
    fp8_meta:
        FP8 meta tensors. Supplies `scale`, `amax_history` and
        `scale_inv` for whichever of the explicit tensors is missing.
    fp8_meta_index:
        Index into the FP8 meta tensors. Required whenever `fp8_meta`
        has to be consulted.
    allow_multiple_offsets:
        If False and the three offsets are not all equal, every tensor
        with a nonzero offset is replaced by its indexed entry and its
        offset is reset to 0, so that one common offset applies to all
        three tensors.

    Returns
    -------
    tensors:
        Dict with keys "scale", "amax", "scale_inv". Factors that are
        neither provided nor available from `fp8_meta` are empty
        tensors.
    offsets:
        Dict with keys "scale_offset", "amax_offset",
        "scale_inv_offset".

    Raises
    ------
    ValueError
        If `fp8_meta` is needed but `fp8_meta_index` is None.
    """

    # Default: use provided scales with no offsets
    scale_offset = 0
    amax_offset = 0
    scale_inv_offset = 0

    # Get scales from FP8 meta tensors if needed
    if (fp8_meta is not None) and any(arg is None for arg in (scale, amax, scale_inv)):
        if fp8_meta_index is None:
            raise ValueError("Provided `fp8_meta` without corresponding `fp8_meta_index`")
        fp8_meta_index = int(fp8_meta_index)
        if scale is None:
            scale = fp8_meta.scale
            scale_offset = fp8_meta_index
        if amax is None:
            amax = fp8_meta.amax_history
            amax_offset = fp8_meta_index
        if scale_inv is None:
            scale_inv = fp8_meta.scale_inv
            scale_inv_offset = fp8_meta_index

    # Construct empty tensors if needed
    if scale is None:
        scale = empty_tensor()
        scale_offset = 0
    if amax is None:
        amax = empty_tensor()
        amax_offset = 0
    if scale_inv is None:
        scale_inv = empty_tensor()
        scale_inv_offset = 0

    # Force offsets to be the same if needed
    if not allow_multiple_offsets and not scale_offset == amax_offset == scale_inv_offset:
        if scale_offset != 0:
            scale = scale[scale_offset]
            scale_offset = 0
        if amax_offset != 0:
            # A nonzero amax offset only arises for `fp8_meta.amax_history`,
            # which is indexed as [history row][tensor index]; row 0 is used.
            amax = amax[0][amax_offset]
            amax_offset = 0
        if scale_inv_offset != 0:
            scale_inv = scale_inv[scale_inv_offset]
            scale_inv_offset = 0

    # Pack tensors and offsets into dicts
    tensors = dict(scale=scale, amax=amax, scale_inv=scale_inv)
    offsets = dict(
        scale_offset=scale_offset,
        amax_offset=amax_offset,
        scale_inv_offset=scale_inv_offset,
    )
    return tensors, offsets
......@@ -3,192 +3,235 @@
# See LICENSE for license information.
"""Python interface for activation extensions"""
from typing import Union
from typing import Optional, Union
import torch
import transformer_engine_torch as tex
import transformer_engine_torch as tex
from ._common import canonicalize_fp8_scales
__all__ = ["gelu", "relu", "reglu", "geglu", "swiglu", "qgelu", "srelu"]
def gelu(
    inp: torch.Tensor,
    fp8_meta_tensor: Optional[tex.FP8TensorMeta],
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None],
    otype: tex.DType,
    scale: Optional[torch.Tensor] = None,
    amax: Optional[torch.Tensor] = None,
    scale_inv: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """GeLU with FP8 output

    Scaling factors passed explicitly (`scale`, `amax`, `scale_inv`)
    are used as given; any that are omitted are looked up in
    `fp8_meta_tensor` at index `fp8_tensor` by
    `canonicalize_fp8_scales`. The op takes a single offset argument,
    so canonicalization is requested with
    `allow_multiple_offsets=False`.
    """

    # Get FP8 scaling factors
    fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales(
        scale=scale,
        amax=amax,
        scale_inv=scale_inv,
        fp8_meta=fp8_meta_tensor,
        fp8_meta_index=fp8_tensor,
        allow_multiple_offsets=False,
    )

    # Launch kernel
    return torch.ops.tex_ts.gelu_ts(
        inp,
        fp8_scales["scale"],
        fp8_scales["amax"],
        fp8_scales["scale_inv"],
        fp8_scales_offsets["scale_offset"],
        otype,
    )
def relu(
    inp: torch.Tensor,
    fp8_meta_tensor: Optional[tex.FP8TensorMeta],
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None],
    otype: tex.DType,
    scale: Optional[torch.Tensor] = None,
    amax: Optional[torch.Tensor] = None,
    scale_inv: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """ReLU with FP8 output

    Scaling factors passed explicitly (`scale`, `amax`, `scale_inv`)
    are used as given; any that are omitted are looked up in
    `fp8_meta_tensor` at index `fp8_tensor` by
    `canonicalize_fp8_scales`. The op takes a single offset argument,
    so canonicalization is requested with
    `allow_multiple_offsets=False`.
    """

    # Get FP8 scaling factors
    fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales(
        scale=scale,
        amax=amax,
        scale_inv=scale_inv,
        fp8_meta=fp8_meta_tensor,
        fp8_meta_index=fp8_tensor,
        allow_multiple_offsets=False,
    )

    # Launch kernel
    return torch.ops.tex_ts.relu_ts(
        inp,
        fp8_scales["scale"],
        fp8_scales["amax"],
        fp8_scales["scale_inv"],
        fp8_scales_offsets["scale_offset"],
        otype,
    )
def geglu(
    inp: torch.Tensor,
    fp8_meta_tensor: Optional[tex.FP8TensorMeta],
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None],
    otype: tex.DType,
    scale: Optional[torch.Tensor] = None,
    amax: Optional[torch.Tensor] = None,
    scale_inv: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """GeGLU with FP8 output

    Scaling factors passed explicitly (`scale`, `amax`, `scale_inv`)
    are used as given; any that are omitted are looked up in
    `fp8_meta_tensor` at index `fp8_tensor` by
    `canonicalize_fp8_scales`. The op takes a single offset argument,
    so canonicalization is requested with
    `allow_multiple_offsets=False`.
    """

    # Get FP8 scaling factors
    fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales(
        scale=scale,
        amax=amax,
        scale_inv=scale_inv,
        fp8_meta=fp8_meta_tensor,
        fp8_meta_index=fp8_tensor,
        allow_multiple_offsets=False,
    )

    # Launch kernel
    return torch.ops.tex_ts.geglu_ts(
        inp,
        fp8_scales["scale"],
        fp8_scales["amax"],
        fp8_scales["scale_inv"],
        fp8_scales_offsets["scale_offset"],
        otype,
    )
def reglu(
    inp: torch.Tensor,
    fp8_meta_tensor: Optional[tex.FP8TensorMeta],
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None],
    otype: tex.DType,
    scale: Optional[torch.Tensor] = None,
    amax: Optional[torch.Tensor] = None,
    scale_inv: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """ReGLU with FP8 output

    Scaling factors passed explicitly (`scale`, `amax`, `scale_inv`)
    are used as given; any that are omitted are looked up in
    `fp8_meta_tensor` at index `fp8_tensor` by
    `canonicalize_fp8_scales`. The op takes a single offset argument,
    so canonicalization is requested with
    `allow_multiple_offsets=False`.
    """

    # Get FP8 scaling factors
    fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales(
        scale=scale,
        amax=amax,
        scale_inv=scale_inv,
        fp8_meta=fp8_meta_tensor,
        fp8_meta_index=fp8_tensor,
        allow_multiple_offsets=False,
    )

    # Launch kernel
    return torch.ops.tex_ts.reglu_ts(
        inp,
        fp8_scales["scale"],
        fp8_scales["amax"],
        fp8_scales["scale_inv"],
        fp8_scales_offsets["scale_offset"],
        otype,
    )
def swiglu(
    inp: torch.Tensor,
    fp8_meta_tensor: Optional[tex.FP8TensorMeta],
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None],
    otype: tex.DType,
    scale: Optional[torch.Tensor] = None,
    amax: Optional[torch.Tensor] = None,
    scale_inv: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """SwiGLU with FP8 output

    Scaling factors passed explicitly (`scale`, `amax`, `scale_inv`)
    are used as given; any that are omitted are looked up in
    `fp8_meta_tensor` at index `fp8_tensor` by
    `canonicalize_fp8_scales`. The op takes a single offset argument,
    so canonicalization is requested with
    `allow_multiple_offsets=False`.
    """

    # Get FP8 scaling factors
    fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales(
        scale=scale,
        amax=amax,
        scale_inv=scale_inv,
        fp8_meta=fp8_meta_tensor,
        fp8_meta_index=fp8_tensor,
        allow_multiple_offsets=False,
    )

    # Launch kernel
    return torch.ops.tex_ts.swiglu_ts(
        inp,
        fp8_scales["scale"],
        fp8_scales["amax"],
        fp8_scales["scale_inv"],
        fp8_scales_offsets["scale_offset"],
        otype,
    )
def qgelu(
    inp: torch.Tensor,
    fp8_meta_tensor: Optional[tex.FP8TensorMeta],
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None],
    otype: tex.DType,
    scale: Optional[torch.Tensor] = None,
    amax: Optional[torch.Tensor] = None,
    scale_inv: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """QuickGELU with FP8 output

    Scaling factors passed explicitly (`scale`, `amax`, `scale_inv`)
    are used as given; any that are omitted are looked up in
    `fp8_meta_tensor` at index `fp8_tensor` by
    `canonicalize_fp8_scales`. The op takes a single offset argument,
    so canonicalization is requested with
    `allow_multiple_offsets=False`.
    """

    # Get FP8 scaling factors
    fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales(
        scale=scale,
        amax=amax,
        scale_inv=scale_inv,
        fp8_meta=fp8_meta_tensor,
        fp8_meta_index=fp8_tensor,
        allow_multiple_offsets=False,
    )

    # Launch kernel
    return torch.ops.tex_ts.qgelu_ts(
        inp,
        fp8_scales["scale"],
        fp8_scales["amax"],
        fp8_scales["scale_inv"],
        fp8_scales_offsets["scale_offset"],
        otype,
    )
def srelu(
    inp: torch.Tensor,
    fp8_meta_tensor: Optional[tex.FP8TensorMeta],
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None],
    otype: tex.DType,
    scale: Optional[torch.Tensor] = None,
    amax: Optional[torch.Tensor] = None,
    scale_inv: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """SReLU with FP8 output

    Dispatches to the `srelu_ts` op (distinct from `relu`, which
    dispatches to `relu_ts`).

    Scaling factors passed explicitly (`scale`, `amax`, `scale_inv`)
    are used as given; any that are omitted are looked up in
    `fp8_meta_tensor` at index `fp8_tensor` by
    `canonicalize_fp8_scales`. The op takes a single offset argument,
    so canonicalization is requested with
    `allow_multiple_offsets=False`.
    """

    # Get FP8 scaling factors
    fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales(
        scale=scale,
        amax=amax,
        scale_inv=scale_inv,
        fp8_meta=fp8_meta_tensor,
        fp8_meta_index=fp8_tensor,
        allow_multiple_offsets=False,
    )

    # Launch kernel
    return torch.ops.tex_ts.srelu_ts(
        inp,
        fp8_scales["scale"],
        fp8_scales["amax"],
        fp8_scales["scale_inv"],
        fp8_scales_offsets["scale_offset"],
        otype,
    )
......@@ -4,57 +4,91 @@
"""Python interface for cast extensions"""
from typing import Optional, Union
import torch
import transformer_engine_torch as tex
import transformer_engine_torch as tex
from ._common import canonicalize_fp8_scales, empty_tensor
__all__ = ["cast_to_fp8", "cast_from_fp8"]
def cast_to_fp8(
    inp: torch.Tensor,
    fp8_meta_tensor: Optional[tex.FP8TensorMeta],
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None],
    otype: tex.DType,
    out: Optional[torch.Tensor] = None,
    scale: Optional[torch.Tensor] = None,
    amax: Optional[torch.Tensor] = None,
    scale_inv: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Cast input to FP8

    Scaling factors passed explicitly (`scale`, `amax`, `scale_inv`)
    are used as given; any that are omitted are looked up in
    `fp8_meta_tensor` at index `fp8_tensor` by
    `canonicalize_fp8_scales`.

    If `out` is provided, the result is written into it and `out` is
    returned; otherwise a new tensor is returned. No kernel is
    launched for an input with zero elements.
    """

    # Get FP8 scaling factors
    fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales(
        scale=scale,
        amax=amax,
        scale_inv=scale_inv,
        fp8_meta=fp8_meta_tensor,
        fp8_meta_index=fp8_tensor,
        allow_multiple_offsets=False,
    )

    # Launch FP8 cast kernel
    if inp.nelement() == 0:
        # Nothing to cast: only make sure there is an output to return
        if out is None:
            out = torch.empty_like(inp, dtype=torch.uint8)
    elif out is None:
        out = torch.ops.tex_ts.cast_to_fp8_ts(
            inp,
            fp8_scales["scale"],
            fp8_scales["amax"],
            fp8_scales["scale_inv"],
            fp8_scales_offsets["scale_offset"],
            otype,
        )
    else:
        torch.ops.tex_ts.cast_to_fp8_noalloc_ts(
            inp,
            fp8_scales["scale"],
            out,
            fp8_scales["amax"],
            fp8_scales["scale_inv"],
            fp8_scales_offsets["scale_offset"],
            otype,
        )
    return out
def cast_from_fp8(
    inp: torch.Tensor,
    fp8_meta_tensor: Optional[tex.FP8TensorMeta],
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None],
    itype: tex.DType,
    otype: tex.DType,
    scale_inv: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Cast input from FP8

    An explicitly passed `scale_inv` is used with offset 0. Otherwise
    the scale-inverse is taken from `fp8_meta_tensor.scale_inv` at
    index `fp8_tensor`; if neither is available an empty tensor is
    passed to the op.

    Raises
    ------
    ValueError
        If `fp8_meta_tensor` has to be consulted but `fp8_tensor` is
        None.
    """

    # Get scaling factors from FP8 meta tensors if needed
    scale_inv_offset = 0
    if (fp8_meta_tensor is not None) and (scale_inv is None):
        if fp8_tensor is None:
            raise ValueError("Provided `fp8_meta_tensor` without corresponding `fp8_tensor`")
        scale_inv = fp8_meta_tensor.scale_inv
        scale_inv_offset = int(fp8_tensor)

    # Construct empty tensors if needed
    if scale_inv is None:
        scale_inv = empty_tensor()
        scale_inv_offset = 0

    # Launch FP8 cast kernel
    return torch.ops.tex_ts.cast_from_fp8_ts(
        inp,
        scale_inv,
        scale_inv_offset,
        itype,
        otype,
    )
......@@ -4,8 +4,11 @@
"""Python interface for normalization extensions"""
from typing import Optional, Tuple, Union
import torch
import transformer_engine_torch as tex
from ._common import canonicalize_fp8_scales
__all__ = [
......@@ -23,46 +26,55 @@ def layernorm_fwd_fp8(
weight: torch.Tensor,
bias: torch.Tensor,
eps: float,
fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
fp8_meta_tensor: Optional[tex.FP8TensorMeta],
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None],
otype: tex.DType,
sm_margin: int,
zero_centered_gamma: bool,
ln_out: Optional[torch.Tensor] = None,
scale: Optional[torch.Tensor] = None,
amax: Optional[torch.Tensor] = None,
scale_inv: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""LayerNorm with FP8 output"""
# Get FP8 scaling factors
fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales(
scale=scale,
amax=amax,
scale_inv=scale_inv,
fp8_meta=fp8_meta_tensor,
fp8_meta_index=fp8_tensor,
)
# Launch kernel
if ln_out is not None:
return tex.layernorm_fwd_fp8_noalloc(
inp,
weight,
bias,
eps,
fp8_meta_tensor.scale,
fp8_scales["scale"],
ln_out,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
fp8_scales["amax"],
fp8_scales["scale_inv"],
otype,
sm_margin,
zero_centered_gamma,
scale_offset=int(fp8_tensor),
amax_offset=int(fp8_tensor),
scale_inv_offset=int(fp8_tensor),
**fp8_scales_offsets,
)
return tex.layernorm_fwd_fp8(
inp,
weight,
bias,
eps,
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
fp8_scales["scale"],
fp8_scales["amax"],
fp8_scales["scale_inv"],
otype,
sm_margin,
zero_centered_gamma,
scale_offset=int(fp8_tensor),
amax_offset=int(fp8_tensor),
scale_inv_offset=int(fp8_tensor),
**fp8_scales_offsets,
)
......@@ -71,26 +83,41 @@ def layernorm_fwd_fp8_inf(
weight: torch.Tensor,
bias: torch.Tensor,
eps: float,
fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
fp8_meta_tensor: Optional[tex.FP8TensorMeta],
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None],
otype: tex.DType,
sm_margin: int,
zero_centered_gamma,
scale: Optional[torch.Tensor] = None,
amax: Optional[torch.Tensor] = None,
scale_inv: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""LayerNorm with FP8 output.
This version of layernorm_fwd_fp8 is specialized for inference, and returns
only the normalized output.
"""
# Get FP8 scaling factors
fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales(
scale=scale,
amax=amax,
scale_inv=scale_inv,
fp8_meta=fp8_meta_tensor,
fp8_meta_index=fp8_tensor,
allow_multiple_offsets=False,
)
# Launch kernel
ret = torch.ops.tex_ts.layernorm_fwd_fp8_inf_ts(
inp,
weight,
bias,
eps,
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
fp8_tensor,
fp8_scales["scale"],
fp8_scales["amax"],
fp8_scales["scale_inv"],
fp8_scales_offsets["scale_offset"],
otype,
sm_margin,
zero_centered_gamma,
......@@ -121,44 +148,53 @@ def rmsnorm_fwd_fp8(
inp: torch.Tensor,
weight: torch.Tensor,
eps: float,
fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
fp8_meta_tensor: Optional[tex.FP8TensorMeta],
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None],
otype: tex.DType,
sm_margin: int,
zero_centered_gamma: bool,
rmsnorm_out: Optional[torch.Tensor] = None,
scale: Optional[torch.Tensor] = None,
amax: Optional[torch.Tensor] = None,
scale_inv: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""RMSNorm with FP8 output"""
# Get FP8 scaling factors
fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales(
scale=scale,
amax=amax,
scale_inv=scale_inv,
fp8_meta=fp8_meta_tensor,
fp8_meta_index=fp8_tensor,
)
# Launch kernel
if rmsnorm_out is not None:
return tex.rmsnorm_fwd_fp8_noalloc(
inp,
weight,
eps,
fp8_meta_tensor.scale,
fp8_scales["scale"],
rmsnorm_out,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
fp8_scales["amax"],
fp8_scales["scale_inv"],
otype,
sm_margin,
zero_centered_gamma,
scale_offset=int(fp8_tensor),
amax_offset=int(fp8_tensor),
scale_inv_offset=int(fp8_tensor),
**fp8_scales_offsets,
)
return tex.rmsnorm_fwd_fp8(
inp,
weight,
eps,
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
fp8_scales["scale"],
fp8_scales["amax"],
fp8_scales["scale_inv"],
otype,
sm_margin,
zero_centered_gamma,
scale_offset=int(fp8_tensor),
amax_offset=int(fp8_tensor),
scale_inv_offset=int(fp8_tensor),
**fp8_scales_offsets,
)
......@@ -166,25 +202,40 @@ def rmsnorm_fwd_fp8_inf(
inp: torch.Tensor,
weight: torch.Tensor,
eps: float,
fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
fp8_meta_tensor: Optional[tex.FP8TensorMeta],
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None],
otype: tex.DType,
sm_margin: int,
zero_centered_gamma,
scale: Optional[torch.Tensor] = None,
amax: Optional[torch.Tensor] = None,
scale_inv: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""RMSNorm with FP8 output.
This version of rmsnorm_fwd_fp8 is specialized for inference, and returns
only the normalized output.
"""
# Get FP8 scaling factors
fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales(
scale=scale,
amax=amax,
scale_inv=scale_inv,
fp8_meta=fp8_meta_tensor,
fp8_meta_index=fp8_tensor,
allow_multiple_offsets=False,
)
# Launch kernel
ret = torch.ops.tex_ts.rmsnorm_fwd_fp8_inf_ts(
inp,
weight,
eps,
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
fp8_tensor,
fp8_scales["scale"],
fp8_scales["amax"],
fp8_scales["scale_inv"],
fp8_scales_offsets["scale_offset"],
otype,
sm_margin,
zero_centered_gamma,
......
......@@ -4,9 +4,12 @@
"""Python interface for transpose extensions"""
from typing import List, Optional, Tuple, Union
import torch
import transformer_engine_torch as tex
from ..constants import TE_DType
from ._common import canonicalize_fp8_scales, empty_tensor
__all__ = [
......@@ -20,83 +23,115 @@ __all__ = [
def fp8_cast_transpose_fused(
inp: torch.Tensor,
fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
fp8_meta_tensor: Optional[tex.FP8TensorMeta],
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None],
otype: tex.DType,
cast_out: Optional[torch.Tensor] = None,
transpose_out: Optional[torch.Tensor] = None,
scale: Optional[torch.Tensor] = None,
amax: Optional[torch.Tensor] = None,
scale_inv: Optional[torch.Tensor] = None,
noop_flag: Optional[torch.Tensor] = None,
) -> Union[Tuple[torch.Tensor, torch.Tensor], None]:
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Cast + Transpose with FP8 output"""
return_outputs = False
# Allocate outputs if needed
if transpose_out is None:
transpose_out = torch.empty(inp.shape[1], inp.shape[0], device="cuda", dtype=torch.uint8)
return_outputs = True
if cast_out is None:
cast_out = torch.empty_like(inp, dtype=torch.uint8)
return_outputs = True
# Get FP8 scaling factors
fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales(
scale=scale,
amax=amax,
scale_inv=scale_inv,
fp8_meta=fp8_meta_tensor,
fp8_meta_index=fp8_tensor,
)
# Construct no-op flag if needed
if noop_flag is None:
noop_flag = torch.Tensor()
noop_flag = empty_tensor()
# Launch kernel if needed
if inp.nelement() > 0:
tex.fused_cast_transpose_noop(
inp,
noop_flag,
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
fp8_scales["scale"],
fp8_scales["amax"],
fp8_scales["scale_inv"],
cast_out,
transpose_out,
otype,
scale_offset=int(fp8_tensor),
amax_offset=int(fp8_tensor),
scale_inv_offset=int(fp8_tensor),
**fp8_scales_offsets,
)
if return_outputs:
return cast_out, transpose_out
return None
return cast_out, transpose_out
def fp8_cast_transpose_bgrad_fused(
inp: torch.Tensor,
fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
fp8_meta_tensor: Optional[tex.FP8TensorMeta],
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None],
otype: tex.DType,
scale: Optional[torch.Tensor] = None,
amax: Optional[torch.Tensor] = None,
scale_inv: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Cast + Transpose + BGRAD with FP8 output"""
# Get FP8 scaling factors
fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales(
scale=scale,
amax=amax,
scale_inv=scale_inv,
fp8_meta=fp8_meta_tensor,
fp8_meta_index=fp8_tensor,
)
# Launch kernel
return tex.fused_cast_transpose_bgrad(
inp,
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
fp8_scales["scale"],
fp8_scales["amax"],
fp8_scales["scale_inv"],
otype,
scale_offset=int(fp8_tensor),
amax_offset=int(fp8_tensor),
scale_inv_offset=int(fp8_tensor),
**fp8_scales_offsets,
)
def fp8_transpose_bgrad_fused(
inp: torch.Tensor,
fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
fp8_meta_tensor: Optional[tex.FP8TensorMeta],
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None],
otype: tex.DType,
grad_bias_type: torch.dtype,
scale: Optional[torch.Tensor] = None,
amax: Optional[torch.Tensor] = None,
scale_inv: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Transpose + BGRAD with FP8 output"""
# Get FP8 scaling factors
fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales(
scale=scale,
amax=amax,
scale_inv=scale_inv,
fp8_meta=fp8_meta_tensor,
fp8_meta_index=fp8_tensor,
)
# Launch kernel
return tex.fused_fp8_transpose_bgrad(
inp,
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
fp8_scales["scale"],
fp8_scales["amax"],
fp8_scales["scale_inv"],
otype,
TE_DType[grad_bias_type],
scale_offset=int(fp8_tensor),
amax_offset=int(fp8_tensor),
scale_inv_offset=int(fp8_tensor),
**fp8_scales_offsets,
)
......@@ -106,18 +141,30 @@ def fp8_cast_transpose_bgrad_dgelu_fused(
fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
otype: tex.DType,
scale: Optional[torch.Tensor] = None,
amax: Optional[torch.Tensor] = None,
scale_inv: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Cast + Transpose + BGRAD + DGELU with FP8 output"""
# Get FP8 scaling factors
fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales(
scale=scale,
amax=amax,
scale_inv=scale_inv,
fp8_meta=fp8_meta_tensor,
fp8_meta_index=fp8_tensor,
)
# Launch kernel
return tex.fused_cast_transpose_bgrad_dgelu(
grad_output,
gelu_input,
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
fp8_scales["scale"],
fp8_scales["amax"],
fp8_scales["scale_inv"],
otype,
scale_offset=int(fp8_tensor),
amax_offset=int(fp8_tensor),
scale_inv_offset=int(fp8_tensor),
**fp8_scales_offsets,
)
......
......@@ -117,13 +117,6 @@ class _ToFloat8Func(torch.autograd.Function):
scale_inv: Optional[torch.Tensor] = None,
) -> Float8Tensor:
# Manually compute scale-inverse if needed
if scale is not None and scale_inv is None:
if isinstance(scale, torch.Tensor):
scale_inv = scale.reciprocal()
else:
scale_inv = 1 / scale
# Extract data from FP8 meta tensors if provided
if fp8_meta is not None:
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
......@@ -138,9 +131,6 @@ class _ToFloat8Func(torch.autograd.Function):
scale = fp8_meta[fp8_meta_key].scale[fp8_meta_index]
if amax is None:
amax = fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index]
if scale_inv is None:
scale_inv = fp8_meta[fp8_meta_key].scale_inv[fp8_meta_index]
scale_inv = scale_inv.detach().view(1).clone()
# Check input tensor
tensor = tensor.contiguous().cuda().detach()
......@@ -163,8 +153,9 @@ class _ToFloat8Func(torch.autograd.Function):
# Check scale-inverse
if scale_inv is None:
scale_inv = scale.reciprocal()
scale_inv = scale_inv.to(device=tensor.device, dtype=torch.float32)
scale_inv = torch.empty([1], dtype=torch.float32, device=tensor.device)
else:
scale_inv = scale_inv.to(device=tensor.device, dtype=torch.float32)
# Check amax
if amax is None:
......@@ -737,19 +728,9 @@ class Float8Tensor(torch.Tensor):
self._fp8_dtype,
cast_out=data,
transpose_out=transpose,
scale_inv=self._scale_inv,
noop_flag=noop_flag,
)
scale = fp8_meta.scale[fp8_meta_index : fp8_meta_index + 1]
scale_inv = self._scale_inv
if noop_flag is None:
torch.reciprocal(scale, out=scale_inv)
else:
torch.where(
noop_flag.bool(),
scale_inv,
scale.reciprocal(),
out=scale_inv,
)
self._transpose_invalid = False
@torch.no_grad()
......@@ -853,7 +834,6 @@ class Float8Tensor(torch.Tensor):
fp8_meta_index = dst._fp8_meta_index
scale = dst._fp8_meta[fp8_meta_key].scale[fp8_meta_index]
amax = dst._fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index]
dst._scale_inv.copy_(scale.detach().reciprocal())
# Cast to FP8
if not dst._data.is_contiguous():
......
......@@ -52,6 +52,9 @@ def _apply_normalization(
fwd_ln_sm_margin: int,
zero_centered_gamma: bool,
is_grad_enabled: bool,
fp8_scale: Optional[torch.Tensor] = None,
fp8_amax: Optional[torch.Tensor] = None,
fp8_scale_inv: Optional[torch.Tensor] = None,
):
normalization_func = _get_normalization_func(normalization, fp8_out, is_grad_enabled, True)
......@@ -70,6 +73,9 @@ def _apply_normalization(
fp8_dtype_forward,
fwd_ln_sm_margin,
zero_centered_gamma,
scale=fp8_scale,
amax=fp8_amax,
scale_inv=fp8_scale_inv,
**output_kwarg,
)
else:
......@@ -82,6 +88,9 @@ def _apply_normalization(
fp8_dtype_forward,
fwd_ln_sm_margin,
zero_centered_gamma,
scale=fp8_scale,
amax=fp8_amax,
scale_inv=fp8_scale_inv,
),
None,
None,
......
......@@ -46,6 +46,7 @@ from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ._common import _apply_normalization, _noop_cat
from ..float8_tensor import Float8Tensor
from ..export import is_in_onnx_export_mode
__all__ = ["LayerNormLinear"]
......@@ -126,8 +127,13 @@ class _LayerNormLinear(torch.autograd.Function):
inputmat, dtype=ln_out_dtype, memory_format=torch.contiguous_format
)
# Objects for FP8 cast
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
ln_out_scale_inv = None
if fp8:
ln_out_scale_inv = torch.empty([1], dtype=torch.float32, device=inputmat.device)
# Launch normalization kernel
ln_out, mu, rsigma = _apply_normalization(
inputmat,
ln_out,
......@@ -140,6 +146,7 @@ class _LayerNormLinear(torch.autograd.Function):
fwd_ln_sm_margin,
zero_centered_gamma,
is_grad_enabled,
fp8_scale_inv=ln_out_scale_inv,
)
# Column Parallel Linear
......@@ -172,6 +179,7 @@ class _LayerNormLinear(torch.autograd.Function):
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
out=ln_out_fp8,
scale_inv=ln_out_scale_inv,
)
ln_out = torch.empty_like(ln_out_fp8)
else:
......@@ -180,6 +188,7 @@ class _LayerNormLinear(torch.autograd.Function):
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
scale_inv=ln_out_scale_inv,
)
if ln_out_gathered:
rank = torch.distributed.get_rank(tp_group)
......@@ -199,6 +208,18 @@ class _LayerNormLinear(torch.autograd.Function):
assert isinstance(weight_fp8, Float8Tensor)
# Hack for ONNX export
# Note: ONNX models are represented as a graph of tensor
# operations, so the in-place scale-inv update doesn't fit
# very well. We work around this by making it look like
# the scale-inv tensor is initialized with a copy.
# Note: ONNX export expects that FP8 scales can be
# represented with constant ops. However, copying into a
# buffer involves an expand op for array broadcasting. We
# work around this by filling the buffer instead.
if is_in_onnx_export_mode():
ln_out_scale_inv.fill_(ln_out_scale_inv.item())
if fp8_meta["recipe"].fp8_mha:
out_index, meta_tensor, output_te_dtype, output_dtype = (
tex.FP8FwdTensors.GEMM1_OUTPUT,
......@@ -219,8 +240,8 @@ class _LayerNormLinear(torch.autograd.Function):
0,
weight_fp8._fp8_dtype,
ln_out_total,
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_INPUT,
ln_out_scale_inv,
0,
fp8_dtype_forward,
output_dtype,
get_workspace(),
......@@ -306,7 +327,7 @@ class _LayerNormLinear(torch.autograd.Function):
weight_fp8,
weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None,
ln_out if weight.requires_grad else None,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
ln_out_scale_inv,
)
ctx.activation_dtype = activation_dtype
......@@ -377,7 +398,7 @@ class _LayerNormLinear(torch.autograd.Function):
weight_fp8,
main_grad,
ln_out,
fwd_scale_inverses,
ln_out_scale_inv,
) = ctx.saved_tensors
# Gather intermediate/activation tensors if needed
......@@ -570,8 +591,8 @@ class _LayerNormLinear(torch.autograd.Function):
ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward)
wgrad, _ = tex.fp8_gemm(
ln_out_total_t,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_INPUT,
ln_out_scale_inv,
0,
fp8_dtype_forward,
(
grad_output_t._data
......@@ -596,8 +617,8 @@ class _LayerNormLinear(torch.autograd.Function):
else:
ln_out_total_c = torch.ops.tex_ts.cast_from_fp8_ts(
ln_out_total,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_INPUT,
ln_out_scale_inv,
0,
fp8_dtype_forward,
TE_DType[ctx.activation_dtype],
)
......
......@@ -48,6 +48,7 @@ from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ..float8_tensor import Float8Tensor
from ..export import is_in_onnx_export_mode
__all__ = ["Linear"]
......@@ -103,10 +104,12 @@ class _Linear(torch.autograd.Function):
inputmat = cast_if_needed(inputmat, activation_dtype)
inputmat_t = None
inputmat_no_fp8 = inputmat
inputmat_scale_inv = None
if fp8:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
if isinstance(inputmat, Float8Tensor):
inputmat_scale_inv = inputmat._scale_inv
if (
not fp8_meta["recipe"].override_linear_precision.wgrad
and is_grad_enabled
......@@ -116,6 +119,7 @@ class _Linear(torch.autograd.Function):
# FP8 input for forward, FP8 input transpose for backward wgrad
inputmat_t = inputmat.transpose_2d()
else:
inputmat_scale_inv = torch.empty([1], dtype=torch.float32, device=inputmat.device)
if (
not fp8_meta["recipe"].override_linear_precision.wgrad
and is_grad_enabled
......@@ -128,6 +132,7 @@ class _Linear(torch.autograd.Function):
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
scale_inv=inputmat_scale_inv,
)
else:
# FP8 input for forward
......@@ -136,8 +141,21 @@ class _Linear(torch.autograd.Function):
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
scale_inv=inputmat_scale_inv,
)
# Hack for ONNX export
# Note: ONNX models are represented as a graph of tensor
# operations, so the in-place scale-inv update doesn't fit
# very well. We work around this by making it look like
# the scale-inv tensor is initialized with a copy.
# Note: ONNX export expects that FP8 scales can be
# represented with constant ops. However, copying into a
# buffer involves an expand op for array broadcasting. We
# work around this by filling the buffer instead.
if is_in_onnx_export_mode():
inputmat_scale_inv.fill_(inputmat_scale_inv.item())
# Column Parallel Linear
if parallel_mode == "column" and sequence_parallel:
inputmat_total, _ = gather_along_first_dim(inputmat, tp_group)
......@@ -206,8 +224,8 @@ class _Linear(torch.autograd.Function):
if isinstance(inputmat_total, Float8Tensor)
else inputmat_total
),
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_INPUT,
inputmat_scale_inv,
0,
fp8_dtype_forward,
proj_out_pttype,
get_workspace(),
......@@ -312,10 +330,10 @@ class _Linear(torch.autograd.Function):
ctx.save_for_backward(
saved_inputmat,
saved_inputmat_t,
inputmat_scale_inv,
weight,
weight_fp8,
weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
)
ctx.activation_dtype = activation_dtype
......@@ -364,10 +382,10 @@ class _Linear(torch.autograd.Function):
(
inputmat,
inputmat_t,
inputmat_scale_inv,
weight,
weight_fp8,
main_grad,
fwd_scale_inverses,
) = ctx.saved_tensors
# Gather intermediate/activation tensors if needed
......@@ -520,8 +538,8 @@ class _Linear(torch.autograd.Function):
if isinstance(inputmat_t_total, Float8Tensor)
else inputmat_t_total
),
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_INPUT,
inputmat_scale_inv,
0,
fp8_dtype_forward,
grad_output_t,
ctx.fp8_meta["scaling_bwd"].scale_inv,
......
......@@ -74,7 +74,7 @@ def is_dtype_bf16(t):
return t.type().scalarType() == "BFloat16"
def quantize(g, inputs, scale_inv, fp8_tensor):
def quantize(g, inputs, scale, fp8_tensor):
"""Helper Function for Quantization"""
output_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inputs)
......@@ -83,7 +83,7 @@ def quantize(g, inputs, scale_inv, fp8_tensor):
if not is_dtype_fp32(inputs):
inputs = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT)
scale = g.op("Constant", value_t=torch.tensor(scale_inv[fp8_tensor]))
scale = g.op("Constant", value_t=torch.tensor(1 / scale[fp8_tensor]))
q_op = g.op(make_op_name("TRT_FP8QuantizeLinear"), inputs, scale).setType(
inputs.type().with_dtype(torch.uint8).with_sizes(output_shape)
)
......@@ -124,18 +124,18 @@ def compute_in_fp32(g, inp, subgraph, *args, **kwargs):
return sg_out
@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i")
@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i")
def onnx_cast_to_fp8(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
"""ONNX graph for cast_to_fp8"""
# pylint: disable=unused-argument
return quantize(g, inputs, scale_inv, fp8_tensor)
return quantize(g, inputs, scale, fp8_tensor)
@symbolic_helper.parse_args("v", "v", "v", "v", "fs", "i", "i")
@symbolic_helper.parse_args("v", "fs", "v", "v", "v", "i", "i")
def onnx_cast_to_fp8_noalloc(g, inputs, scale, output, amax, scale_inv, fp8_tensor, otype):
"""ONNX graph for cast_to_fp8_noalloc"""
# pylint: disable=unused-argument
return quantize(g, inputs, scale_inv, fp8_tensor)
return quantize(g, inputs, scale, fp8_tensor)
@symbolic_helper.parse_args("v", "fs", "i", "i", "i")
......@@ -145,25 +145,25 @@ def onnx_cast_from_fp8(g, inputs, scale_inv, fp8_tensor, itype, otype):
return dequantize(g, inputs, scale_inv, fp8_tensor, otype)
@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i")
@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i")
def onnx_fp8_gelu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
"""ONNX graph for fp8_gelu"""
# pylint: disable=unused-argument
# TE computes GELU using float32 precision so wrap the GELU subgraph with
# conversion to/from float32.
gelu = compute_in_fp32(g, inputs, torch.onnx.symbolic_opset9.gelu, "tanh")
if scale_inv:
gelu = quantize(g, gelu, scale_inv, fp8_tensor)
if scale:
gelu = quantize(g, gelu, scale, fp8_tensor)
return gelu
@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i")
@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i")
def onnx_fp8_relu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
"""ONNX graph for fp8_relu"""
# pylint: disable=unused-argument
relu = compute_in_fp32(g, inputs, torch.onnx.symbolic_opset9.relu)
if scale_inv:
relu = quantize(g, relu, scale_inv, fp8_tensor)
if scale:
relu = quantize(g, relu, scale, fp8_tensor)
return relu
......@@ -178,13 +178,13 @@ def onnx_swiglu(g: jit_utils.GraphContext, inp, dim):
return g.op("Mul", g.op("Sigmoid", first), second)
@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i")
@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i")
def onnx_fp8_swiglu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
"""ONNX graph for fp8_swiglu"""
# pylint: disable=unused-argument
swiglu = compute_in_fp32(g, inputs, onnx_swiglu, 1)
if scale_inv:
swiglu = quantize(g, swiglu, scale_inv, fp8_tensor)
if scale:
swiglu = quantize(g, swiglu, scale, fp8_tensor)
return swiglu
......@@ -199,13 +199,13 @@ def onnx_reglu(g: jit_utils.GraphContext, inp, dim):
return g.op("Mul", g.op("Relu", first), second)
@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i")
@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i")
def onnx_fp8_reglu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
"""ONNX graph for fp8_reglu"""
# pylint: disable=unused-argument
reglu = compute_in_fp32(g, inputs, onnx_reglu, 1)
if scale_inv:
reglu = quantize(g, reglu, scale_inv, fp8_tensor)
if scale:
reglu = quantize(g, reglu, scale, fp8_tensor)
return reglu
......@@ -221,13 +221,13 @@ def onnx_geglu(g: jit_utils.GraphContext, inp, dim):
return g.op("Mul", first_gelu, second)
@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i")
@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i")
def onnx_fp8_geglu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
"""ONNX graph for fp8_geglu"""
# pylint: disable=unused-argument
geglu = compute_in_fp32(g, inputs, onnx_geglu, 1)
if scale_inv:
geglu = quantize(g, geglu, scale_inv, fp8_tensor)
if scale:
geglu = quantize(g, geglu, scale, fp8_tensor)
return geglu
......@@ -245,7 +245,7 @@ def onnx_fp8_geglu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
"v",
"fs",
"i",
"fs",
"v",
"v",
"i",
"v",
......@@ -330,7 +330,7 @@ def _ones_like(g, inp, dtype):
return one
@symbolic_helper.parse_args("v", "v", "v", "f", "v", "v", "fs", "i", "i", "i", "b")
@symbolic_helper.parse_args("v", "v", "v", "f", "fs", "v", "v", "i", "i", "i", "b")
def onnx_layernorm_fwd_fp8(
g,
inputs,
......@@ -355,7 +355,7 @@ def onnx_layernorm_fwd_fp8(
bias = g.op("Cast", bias, to_i=inp_dtype)
ln = onnx_layernorm_fwd(g, inputs, weight, bias, eps, sm_margin, zero_centered_gamma)
fp8_ln = quantize(g, ln, scale_inv, fp8_tensor)
fp8_ln = quantize(g, ln, scale, fp8_tensor)
return fp8_ln
......@@ -391,7 +391,7 @@ def onnx_layernorm_fwd(g, inputs, weight, bias, eps, sm_margin, zero_centered_ga
return ln
@symbolic_helper.parse_args("v", "v", "f", "v", "v", "fs", "i", "i", "i", "b")
@symbolic_helper.parse_args("v", "v", "f", "fs", "v", "v", "i", "i", "i", "b")
def onnx_rmsnorm_fwd_fp8(
g,
inputs,
......@@ -413,7 +413,7 @@ def onnx_rmsnorm_fwd_fp8(
weight = g.op("Cast", weight, to_i=inp_dtype)
ln = onnx_rmsnorm_fwd(g, inputs, weight, eps, sm_margin, zero_centered_gamma)
fp8_ln = quantize(g, ln, scale_inv, fp8_tensor)
fp8_ln = quantize(g, ln, scale, fp8_tensor)
return fp8_ln
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment