Unverified Commit 8e3561bf authored by Tim Moon, committed by GitHub

Update FP8 scale-inverse in kernels with FP8 output (#1083)



* Perform scale-inv update in cast-transpose kernels
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Perform scale-inv update in cast and activation kernels
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Perform scale-inv update in LayerNorm and RMSNorm kernels
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Perform scale-inv update after FP8 GEMMs
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fuse casts and scale-inv updates in linear module
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fuse casts and scale-inv updates in layernorm-linear module
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Simplify kernel to update FP8 scale-inv
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix typos
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug amax update in layernorm kernels
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Debug test failures
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug ONNX export

Use quantization scaling factor in ONNX quantize op.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Review suggestion from @ptrendx
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Debug mismatched dtypes
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 5d5fe819
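Every kernel touched below already loads the FP8 scale and reduces an amax for its output; the change is to have that same kernel also write scale_inv = 1 / scale, so the scale-inverse never needs a separate kernel launch to stay in sync with the scale. A rough Python sketch of the behavior being fused, illustrative only and not the Transformer Engine implementation (the function name and the E4M3 clamp range are assumptions for the example):

import torch

def fp8_cast_reference(x, scale, amax, scale_inv):
    """Emulate a cast kernel with delayed scaling: quantize, record amax, update scale-inv."""
    amax.fill_(x.abs().max())             # amax reduction (done per thread block in the real kernels)
    scale_inv.copy_(scale.reciprocal())   # scale-inv update, now performed by the same kernel
    # The cast itself: scale and clamp to the FP8 E4M3 representable range
    return (x * scale).clamp(min=-448.0, max=448.0)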
@@ -229,12 +229,16 @@ __global__ void __launch_bounds__(BLOCK_SIZE)
      }
    }

    // Reduce amax over block
    if (param.amax != nullptr) {
      const CType max_block = reduce_max<BLOCK_SIZE / THREADS_PER_WARP>(amax, warp_id);
      if (threadIdx.x == 0) {
        atomicMaxFloat(param.amax, max_block);
      }
    }

    // Update scale-inverse
    if (blockIdx.x == 0 && threadIdx.x == 0 && param.scale_inv != nullptr) {
      reciprocal<CType>(param.scale_inv, scale);
    }
  }

@@ -46,7 +46,8 @@ void fp8_quantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
          reinterpret_cast<const IType *>(input.data.dptr),
          reinterpret_cast<OType *>(output->data.dptr),
          reinterpret_cast<const fp32 *>(output->scale.dptr),
          reinterpret_cast<fp32 *>(output->amax.dptr),
          reinterpret_cast<fp32 *>(output->scale_inv.dptr), N, {},
          stream);); // NOLINT(*)
  );  // NOLINT(*)
}

@@ -68,7 +69,7 @@ void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
      p.scale_inv = reinterpret_cast<const fp32 *>(input.scale_inv.dptr);
      VectorizedUnaryKernelLauncher<nvec, detail::DequantizeParam, detail::dequantize_func>(
          reinterpret_cast<const IType *>(input.data.dptr),
          reinterpret_cast<OType *>(output->data.dptr), nullptr, nullptr, nullptr, N, p,
          stream);); // NOLINT(*)
  );  // NOLINT(*)
}
......
@@ -168,12 +168,12 @@ template <int nvec, bool aligned, typename ComputeType, typename Param,
          ComputeType (*OP)(ComputeType, const Param &), typename InputType, typename OutputType>
__launch_bounds__(unary_kernel_threads) __global__
    void unary_kernel(const InputType *input, OutputType *output, const ComputeType *scale,
                      ComputeType *amax, ComputeType *scale_inv, Param p, const size_t N,
                      const size_t num_aligned_elements) {
  VectorizedLoader<InputType, nvec, aligned> loader(input, N);
  VectorizedStorer<OutputType, nvec, aligned> storer(output, N);
  ComputeType max = 0;
  ComputeType s = 1;
  if constexpr (is_fp8<OutputType>::value) {
    if (scale != nullptr) s = *scale;
  }

@@ -199,12 +199,18 @@ __launch_bounds__(unary_kernel_threads) __global__
      storer.store(tid, N);
  }
  if constexpr (is_fp8<OutputType>::value) {
    // Reduce amax over block
    if (amax != nullptr) {
      max = reduce_max<unary_kernel_threads / THREADS_PER_WARP>(max, warp_id);
      if (threadIdx.x == 0) {
        static_assert(std::is_same<ComputeType, float>::value);
        atomicMaxFloat(amax, max);
      }
    }

    // Update scale-inverse
    if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) {
      reciprocal<ComputeType>(scale_inv, s);
    }
  }
}

@@ -214,13 +220,13 @@ template <int nvec, bool aligned, typename ComputeType, typename Param,
          typename OutputType>
__launch_bounds__(unary_kernel_threads) __global__
    void unary_grad_kernel(const InputTypeGrad *grad, const InputType *input, OutputType *output,
                           const ComputeType *scale, ComputeType *amax, ComputeType *scale_inv,
                           Param p, const size_t N, const size_t num_aligned_elements) {
  VectorizedLoader<InputType, nvec, aligned> loader(input, N);
  VectorizedLoader<InputTypeGrad, nvec, aligned> grad_loader(grad, N);
  VectorizedStorer<OutputType, nvec, aligned> storer(output, N);
  ComputeType max = 0;
  ComputeType s = 1;
  if constexpr (is_fp8<OutputType>::value) {
    if (scale != nullptr) s = *scale;
  }

@@ -248,12 +254,18 @@ __launch_bounds__(unary_kernel_threads) __global__
      storer.store(tid, N);
  }
  if constexpr (is_fp8<OutputType>::value) {
    // Reduce amax over block
    if (amax != nullptr) {
      max = reduce_max<unary_kernel_threads / THREADS_PER_WARP>(max, warp_id);
      if (threadIdx.x == 0) {
        static_assert(std::is_same<ComputeType, float>::value);
        atomicMaxFloat(amax, max);
      }
    }

    // Update scale-inverse
    if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) {
      reciprocal<ComputeType>(scale_inv, s);
    }
  }
}
@@ -311,7 +323,7 @@ Alignment CheckAlignment(const size_t lead_dim, const int nvec, const T... ptrs)
template <int nvec, typename Param, fp32 (*OP)(const fp32, const Param &), typename InputType,
          typename OutputType>
void VectorizedUnaryKernelLauncher(const InputType *input, OutputType *output, const fp32 *scale,
                                   fp32 *amax, fp32 *scale_inv, const size_t N, const Param params,
                                   cudaStream_t stream) {
  if (N != 0) {
    auto align = CheckAlignment(N, nvec, input, output);

@@ -325,16 +337,16 @@ void VectorizedUnaryKernelLauncher(const InputType *input, OutputType *output, c
    switch (align) {
      case Alignment::SAME_ALIGNED:
        unary_kernel<nvec, true, fp32, Param, OP><<<num_blocks, threads, 0, stream>>>(
            input, output, scale, amax, scale_inv, params, N, num_aligned_elements);
        break;
      case Alignment::SAME_UNALIGNED:
        unary_kernel<nvec, false, fp32, Param, OP><<<num_blocks, threads, 0, stream>>>(
            input, output, scale, amax, scale_inv, params, N, num_aligned_elements);
        break;
      case Alignment::DIFFERENT: {
        // If the pointers are aligned differently we cannot vectorize
        unary_kernel<1, true, fp32, Param, OP><<<num_blocks, threads, 0, stream>>>(
            input, output, scale, amax, scale_inv, params, N, N);
        break;
      }
    }

@@ -345,7 +357,8 @@ template <int nvec, typename Param, fp32 (*OP)(fp32, const Param &), typename In
          typename InputTypeGrad, typename OutputType>
void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad, const InputType *input,
                                       OutputType *output, const fp32 *scale, fp32 *amax,
                                       fp32 *scale_inv, const size_t N, const Param params,
                                       cudaStream_t stream) {
  if (N != 0) {
    auto align = CheckAlignment(N, nvec, input, grad, output);

@@ -358,16 +371,16 @@ void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad, const InputTyp
    switch (align) {
      case Alignment::SAME_ALIGNED:
        unary_grad_kernel<nvec, true, fp32, Param, OP><<<num_blocks, threads, 0, stream>>>(
            grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements);
        break;
      case Alignment::SAME_UNALIGNED:
        unary_grad_kernel<nvec, false, fp32, Param, OP><<<num_blocks, threads, 0, stream>>>(
            grad, input, output, scale, amax, scale_inv, params, N, num_aligned_elements);
        break;
      case Alignment::DIFFERENT: {
        // If the pointers are aligned differently we cannot vectorize
        unary_grad_kernel<1, true, fp32, Param, OP><<<num_blocks, threads, 0, stream>>>(
            grad, input, output, scale, amax, scale_inv, params, N, N);
        break;
      }
    }
@@ -379,8 +392,8 @@ template <int nvec, bool aligned, typename ComputeType, typename Param,
          typename OutputType>
__launch_bounds__(unary_kernel_threads) __global__
    void gated_act_kernel(const InputType *input, OutputType *output, const ComputeType *scale,
                          ComputeType *amax, ComputeType *scale_inv, const size_t m, const size_t n,
                          const Param p, const size_t num_aligned_elements) {
  const size_t M = num_aligned_elements * m;
  for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) {
    const size_t id_x = tid % num_aligned_elements;

@@ -389,7 +402,7 @@ __launch_bounds__(unary_kernel_threads) __global__
    VectorizedLoader<InputType, nvec, aligned> loader1(input + id_y * n * 2 + n, n);
    VectorizedStorer<OutputType, nvec, aligned> storer(output + id_y * n, n);
    ComputeType max = 0;
    ComputeType s = 1;
    if constexpr (is_fp8<OutputType>::value) {
      if (scale != nullptr) s = *scale;
    }

@@ -412,12 +425,18 @@ __launch_bounds__(unary_kernel_threads) __global__
        storer.store(id_x, n);
    if constexpr (is_fp8<OutputType>::value) {
      // Reduce amax over block
      if (amax != nullptr) {
        max = reduce_max<unary_kernel_threads / THREADS_PER_WARP>(max, warp_id);
        if (threadIdx.x == 0) {
          static_assert(std::is_same<ComputeType, float>::value);
          atomicMaxFloat(amax, max);
        }
      }

      // Update scale-inverse
      if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) {
        reciprocal<ComputeType>(scale_inv, s);
      }
    }
  }

@@ -427,8 +446,8 @@ template <int nvec, typename ComputeType, typename Param,
          ComputeType (*Activation)(const ComputeType, const Param &), typename InputType,
          typename OutputType>
void GatedActivationKernelLauncher(const InputType *input, OutputType *output, const fp32 *scale,
                                   fp32 *amax, fp32 *scale_inv, const size_t m, const size_t n,
                                   const Param &p, cudaStream_t stream) {
  if (m != 0 && n != 0) {
    size_t num_aligned_elements = get_num_aligned_elements(input, n, nvec, sizeof(InputType));
    constexpr size_t threads = unary_kernel_threads;

@@ -439,18 +458,18 @@ void GatedActivationKernelLauncher(const InputType *input, OutputType *output, c
    switch (auto align = CheckAlignment(n, nvec, input, input + n, output)) {
      case Alignment::SAME_ALIGNED:
        gated_act_kernel<nvec, true, ComputeType, Param, Activation>
            <<<num_blocks, threads, 0, stream>>>(input, output, scale, amax, scale_inv, m, n, p,
                                                 num_aligned_elements);
        break;
      case Alignment::SAME_UNALIGNED:
        gated_act_kernel<nvec, false, ComputeType, Param, Activation>
            <<<num_blocks, threads, 0, stream>>>(input, output, scale, amax, scale_inv, m, n, p,
                                                 num_aligned_elements);
        break;
      case Alignment::DIFFERENT: {
        // If the pointers are aligned differently we cannot vectorize
        gated_act_kernel<1, true, ComputeType, Param, Activation>
            <<<num_blocks, threads, 0, stream>>>(input, output, scale, amax, scale_inv, m, n, p, n);
        break;
      }
    }
......
@@ -852,6 +852,11 @@ __device__ __forceinline__ void reciprocal(T *value_inv, const T value) {
  *value_inv = 1 / value;
}

template <>
__device__ __forceinline__ void reciprocal<float>(float *value_inv, const float value) {
  *value_inv = __frcp_rn(value);
}

}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_COMMON_UTILS_CUH_
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Helper functions for C++ extensions"""
import functools
from typing import Dict, Optional, Tuple, Union
import torch
import transformer_engine_torch as tex
@functools.lru_cache(maxsize=None)
def empty_tensor() -> torch.Tensor:
"""Get tensor with no entries and no data"""
return torch.Tensor()
def canonicalize_fp8_scales(
*,
scale: Optional[torch.Tensor] = None,
amax: Optional[torch.Tensor] = None,
scale_inv: Optional[torch.Tensor] = None,
fp8_meta: Optional[tex.FP8TensorMeta] = None,
fp8_meta_index: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None] = None,
allow_multiple_offsets: bool = True,
) -> Tuple[Dict[str, torch.Tensor], Dict[str, int]]:
"""Canonicalize FP8 scaling factors (scale, amax, scale-inverse)
If a scaling factor is not provided, try to access it within the
FP8 meta tensors. Returns dict with tensors and dict with tensor
offsets.
"""
# Default: use provided scales with no offsets
scale_offset = 0
amax_offset = 0
scale_inv_offset = 0
# Get scales from FP8 meta tensors if needed
if (fp8_meta is not None) and any(arg is None for arg in (scale, amax, scale_inv)):
if fp8_meta_index is None:
raise ValueError("Provided `fp8_meta` without corresponding `fp8_meta_index`")
fp8_meta_index = int(fp8_meta_index)
if scale is None:
scale = fp8_meta.scale
scale_offset = fp8_meta_index
if amax is None:
amax = fp8_meta.amax_history
amax_offset = fp8_meta_index
if scale_inv is None:
scale_inv = fp8_meta.scale_inv
scale_inv_offset = fp8_meta_index
# Construct empty tensors if needed
if scale is None:
scale = empty_tensor()
scale_offset = 0
if amax is None:
amax = empty_tensor()
amax_offset = 0
if scale_inv is None:
scale_inv = empty_tensor()
scale_inv_offset = 0
# Force offsets to be the same if needed
if not allow_multiple_offsets and not scale_offset == amax_offset == scale_inv_offset:
if scale_offset != 0:
scale = scale[scale_offset]
scale_offset = 0
if amax_offset != 0:
amax = amax[0][amax_offset]
amax_offset = 0
if scale_inv_offset != 0:
scale_inv = scale_inv[scale_inv_offset]
scale_inv_offset = 0
# Pack tensors and offsets into dicts
tensors = dict(scale=scale, amax=amax, scale_inv=scale_inv)
offsets = dict(
scale_offset=scale_offset,
amax_offset=amax_offset,
scale_inv_offset=scale_inv_offset,
)
return tensors, offsets
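A hypothetical usage sketch of the helper above (`meta` is assumed to be a populated tex.FP8TensorMeta; it is not defined in this file): explicitly passed tensors are used as-is with offset 0, while anything left as None falls back to the corresponding meta tensor with the meta index as its offset.

import torch
import transformer_engine_torch as tex

scale_override = torch.ones(1, device="cuda")
tensors, offsets = canonicalize_fp8_scales(
    scale=scale_override,      # explicit scale, so scale_offset stays 0
    amax=None,                 # falls back to meta.amax_history
    scale_inv=None,            # falls back to meta.scale_inv
    fp8_meta=meta,
    fp8_meta_index=tex.FP8FwdTensors.GEMM1_INPUT,
)
assert tensors["scale"] is scale_override and offsets["scale_offset"] == 0
assert offsets["amax_offset"] == int(tex.FP8FwdTensors.GEMM1_INPUT)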
@@ -3,192 +3,235 @@
# See LICENSE for license information.

"""Python interface for activation extensions"""
from typing import Optional, Union

import torch

import transformer_engine_torch as tex

from ._common import canonicalize_fp8_scales

__all__ = ["gelu", "relu", "reglu", "geglu", "swiglu", "qgelu", "srelu"]

def gelu(
    inp: torch.Tensor,
    fp8_meta_tensor: Optional[tex.FP8TensorMeta],
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None],
    otype: tex.DType,
    scale: Optional[torch.Tensor] = None,
    amax: Optional[torch.Tensor] = None,
    scale_inv: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """GeLU with FP8 output"""

    # Get FP8 scaling factors
    fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales(
        scale=scale,
        amax=amax,
        scale_inv=scale_inv,
        fp8_meta=fp8_meta_tensor,
        fp8_meta_index=fp8_tensor,
        allow_multiple_offsets=False,
    )

    # Launch kernel
    return torch.ops.tex_ts.gelu_ts(
        inp,
        fp8_scales["scale"],
        fp8_scales["amax"],
        fp8_scales["scale_inv"],
        fp8_scales_offsets["scale_offset"],
        otype,
    )

def relu(
    inp: torch.Tensor,
    fp8_meta_tensor: Optional[tex.FP8TensorMeta],
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None],
    otype: tex.DType,
    scale: Optional[torch.Tensor] = None,
    amax: Optional[torch.Tensor] = None,
    scale_inv: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """ReLU with FP8 output"""

    # Get FP8 scaling factors
    fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales(
        scale=scale,
        amax=amax,
        scale_inv=scale_inv,
        fp8_meta=fp8_meta_tensor,
        fp8_meta_index=fp8_tensor,
        allow_multiple_offsets=False,
    )

    # Launch kernel
    return torch.ops.tex_ts.relu_ts(
        inp,
        fp8_scales["scale"],
        fp8_scales["amax"],
        fp8_scales["scale_inv"],
        fp8_scales_offsets["scale_offset"],
        otype,
    )

def geglu(
    inp: torch.Tensor,
    fp8_meta_tensor: Optional[tex.FP8TensorMeta],
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None],
    otype: tex.DType,
    scale: Optional[torch.Tensor] = None,
    amax: Optional[torch.Tensor] = None,
    scale_inv: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """GeGLU with FP8 output"""

    # Get FP8 scaling factors
    fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales(
        scale=scale,
        amax=amax,
        scale_inv=scale_inv,
        fp8_meta=fp8_meta_tensor,
        fp8_meta_index=fp8_tensor,
        allow_multiple_offsets=False,
    )

    # Launch kernel
    return torch.ops.tex_ts.geglu_ts(
        inp,
        fp8_scales["scale"],
        fp8_scales["amax"],
        fp8_scales["scale_inv"],
        fp8_scales_offsets["scale_offset"],
        otype,
    )

def reglu(
    inp: torch.Tensor,
    fp8_meta_tensor: Optional[tex.FP8TensorMeta],
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None],
    otype: tex.DType,
    scale: Optional[torch.Tensor] = None,
    amax: Optional[torch.Tensor] = None,
    scale_inv: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """ReGLU with FP8 output"""

    # Get FP8 scaling factors
    fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales(
        scale=scale,
        amax=amax,
        scale_inv=scale_inv,
        fp8_meta=fp8_meta_tensor,
        fp8_meta_index=fp8_tensor,
        allow_multiple_offsets=False,
    )

    # Launch kernel
    return torch.ops.tex_ts.reglu_ts(
        inp,
        fp8_scales["scale"],
        fp8_scales["amax"],
        fp8_scales["scale_inv"],
        fp8_scales_offsets["scale_offset"],
        otype,
    )

def swiglu(
    inp: torch.Tensor,
    fp8_meta_tensor: Optional[tex.FP8TensorMeta],
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None],
    otype: tex.DType,
    scale: Optional[torch.Tensor] = None,
    amax: Optional[torch.Tensor] = None,
    scale_inv: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """SwiGLU with FP8 output"""

    # Get FP8 scaling factors
    fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales(
        scale=scale,
        amax=amax,
        scale_inv=scale_inv,
        fp8_meta=fp8_meta_tensor,
        fp8_meta_index=fp8_tensor,
        allow_multiple_offsets=False,
    )

    # Launch kernel
    return torch.ops.tex_ts.swiglu_ts(
        inp,
        fp8_scales["scale"],
        fp8_scales["amax"],
        fp8_scales["scale_inv"],
        fp8_scales_offsets["scale_offset"],
        otype,
    )

def qgelu(
    inp: torch.Tensor,
    fp8_meta_tensor: Optional[tex.FP8TensorMeta],
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None],
    otype: tex.DType,
    scale: Optional[torch.Tensor] = None,
    amax: Optional[torch.Tensor] = None,
    scale_inv: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """QuickGELU with FP8 output"""

    # Get FP8 scaling factors
    fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales(
        scale=scale,
        amax=amax,
        scale_inv=scale_inv,
        fp8_meta=fp8_meta_tensor,
        fp8_meta_index=fp8_tensor,
        allow_multiple_offsets=False,
    )

    # Launch kernel
    return torch.ops.tex_ts.qgelu_ts(
        inp,
        fp8_scales["scale"],
        fp8_scales["amax"],
        fp8_scales["scale_inv"],
        fp8_scales_offsets["scale_offset"],
        otype,
    )

def srelu(
    inp: torch.Tensor,
    fp8_meta_tensor: Optional[tex.FP8TensorMeta],
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None],
    otype: tex.DType,
    scale: Optional[torch.Tensor] = None,
    amax: Optional[torch.Tensor] = None,
    scale_inv: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """ReLU with FP8 output"""

    # Get FP8 scaling factors
    fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales(
        scale=scale,
        amax=amax,
        scale_inv=scale_inv,
        fp8_meta=fp8_meta_tensor,
        fp8_meta_index=fp8_tensor,
        allow_multiple_offsets=False,
    )

    # Launch kernel
    return torch.ops.tex_ts.srelu_ts(
        inp,
        fp8_scales["scale"],
        fp8_scales["amax"],
        fp8_scales["scale_inv"],
        fp8_scales_offsets["scale_offset"],
        otype,
    )
@@ -4,57 +4,91 @@
"""Python interface for cast extensions"""
from typing import Optional, Union

import torch

import transformer_engine_torch as tex

from ._common import canonicalize_fp8_scales, empty_tensor

__all__ = ["cast_to_fp8", "cast_from_fp8"]

def cast_to_fp8(
    inp: torch.Tensor,
    fp8_meta_tensor: Optional[tex.FP8TensorMeta],
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None],
    otype: tex.DType,
    out: Optional[torch.Tensor] = None,
    scale: Optional[torch.Tensor] = None,
    amax: Optional[torch.Tensor] = None,
    scale_inv: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Cast input to FP8"""

    # Get FP8 scaling factors
    fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales(
        scale=scale,
        amax=amax,
        scale_inv=scale_inv,
        fp8_meta=fp8_meta_tensor,
        fp8_meta_index=fp8_tensor,
        allow_multiple_offsets=False,
    )

    # Launch FP8 cast kernel
    if inp.nelement() == 0:
        if out is None:
            out = torch.empty_like(inp, dtype=torch.uint8)
    elif out is None:
        out = torch.ops.tex_ts.cast_to_fp8_ts(
            inp,
            fp8_scales["scale"],
            fp8_scales["amax"],
            fp8_scales["scale_inv"],
            fp8_scales_offsets["scale_offset"],
            otype,
        )
    else:
        torch.ops.tex_ts.cast_to_fp8_noalloc_ts(
            inp,
            fp8_scales["scale"],
            out,
            fp8_scales["amax"],
            fp8_scales["scale_inv"],
            fp8_scales_offsets["scale_offset"],
            otype,
        )
    return out

def cast_from_fp8(
    inp: torch.Tensor,
    fp8_meta_tensor: Optional[tex.FP8TensorMeta],
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None],
    itype: tex.DType,
    otype: tex.DType,
    scale_inv: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Cast input from FP8"""

    # Get scaling factors from FP8 meta tensors if needed
    scale_inv_offset = 0
    if (fp8_meta_tensor is not None) and (scale_inv is None):
        if fp8_tensor is None:
            raise ValueError("Provided `fp8_meta_tensor` without corresponding `fp8_tensor`")
        scale_inv = fp8_meta_tensor.scale_inv
        scale_inv_offset = int(fp8_tensor)

    # Construct empty tensors if needed
    if scale_inv is None:
        scale_inv = empty_tensor()
        scale_inv_offset = 0

    # Launch FP8 cast kernel
    return torch.ops.tex_ts.cast_from_fp8_ts(
        inp,
        scale_inv,
        scale_inv_offset,
        itype,
        otype,
    )
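A hypothetical round-trip sketch for the two functions above, using the new explicit-scale path instead of an FP8 meta struct (assumes a CUDA device and the compiled transformer_engine_torch extension; the amax buffer shape mimics a single-entry amax history and is an assumption):

x = torch.randn(16, 16, device="cuda")
scale = torch.full((1,), 2.0, device="cuda")
amax = torch.zeros(1, 1, device="cuda")     # assumed single-entry amax history
scale_inv = torch.empty(1, device="cuda")   # filled by the cast kernel

x_fp8 = cast_to_fp8(x, None, None, tex.DType.kFloat8E4M3,
                    scale=scale, amax=amax, scale_inv=scale_inv)
x_back = cast_from_fp8(x_fp8, None, None, tex.DType.kFloat8E4M3,
                       tex.DType.kFloat32, scale_inv=scale_inv)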
@@ -4,8 +4,11 @@
"""Python interface for normalization extensions"""
from typing import Optional, Tuple, Union

import torch

import transformer_engine_torch as tex

from ._common import canonicalize_fp8_scales

__all__ = [

@@ -23,46 +26,55 @@ def layernorm_fwd_fp8(
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
    fp8_meta_tensor: Optional[tex.FP8TensorMeta],
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None],
    otype: tex.DType,
    sm_margin: int,
    zero_centered_gamma: bool,
    ln_out: Optional[torch.Tensor] = None,
    scale: Optional[torch.Tensor] = None,
    amax: Optional[torch.Tensor] = None,
    scale_inv: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """LayerNorm with FP8 output"""

    # Get FP8 scaling factors
    fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales(
        scale=scale,
        amax=amax,
        scale_inv=scale_inv,
        fp8_meta=fp8_meta_tensor,
        fp8_meta_index=fp8_tensor,
    )

    # Launch kernel
    if ln_out is not None:
        return tex.layernorm_fwd_fp8_noalloc(
            inp,
            weight,
            bias,
            eps,
            fp8_scales["scale"],
            ln_out,
            fp8_scales["amax"],
            fp8_scales["scale_inv"],
            otype,
            sm_margin,
            zero_centered_gamma,
            **fp8_scales_offsets,
        )
    return tex.layernorm_fwd_fp8(
        inp,
        weight,
        bias,
        eps,
        fp8_scales["scale"],
        fp8_scales["amax"],
        fp8_scales["scale_inv"],
        otype,
        sm_margin,
        zero_centered_gamma,
        **fp8_scales_offsets,
    )

@@ -71,26 +83,41 @@ def layernorm_fwd_fp8_inf(
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
    fp8_meta_tensor: Optional[tex.FP8TensorMeta],
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None],
    otype: tex.DType,
    sm_margin: int,
    zero_centered_gamma,
    scale: Optional[torch.Tensor] = None,
    amax: Optional[torch.Tensor] = None,
    scale_inv: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """LayerNorm with FP8 output.

    This version of layernorm_fwd_fp8 is specialized for inference, and returns
    only the normalized output.
    """

    # Get FP8 scaling factors
    fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales(
        scale=scale,
        amax=amax,
        scale_inv=scale_inv,
        fp8_meta=fp8_meta_tensor,
        fp8_meta_index=fp8_tensor,
        allow_multiple_offsets=False,
    )

    # Launch kernel
    ret = torch.ops.tex_ts.layernorm_fwd_fp8_inf_ts(
        inp,
        weight,
        bias,
        eps,
        fp8_scales["scale"],
        fp8_scales["amax"],
        fp8_scales["scale_inv"],
        fp8_scales_offsets["scale_offset"],
        otype,
        sm_margin,
        zero_centered_gamma,

@@ -121,44 +148,53 @@ def rmsnorm_fwd_fp8(
    inp: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
    fp8_meta_tensor: Optional[tex.FP8TensorMeta],
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None],
    otype: tex.DType,
    sm_margin: int,
    zero_centered_gamma: bool,
    rmsnorm_out: Optional[torch.Tensor] = None,
    scale: Optional[torch.Tensor] = None,
    amax: Optional[torch.Tensor] = None,
    scale_inv: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """RMSNorm with FP8 output"""

    # Get FP8 scaling factors
    fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales(
        scale=scale,
        amax=amax,
        scale_inv=scale_inv,
        fp8_meta=fp8_meta_tensor,
        fp8_meta_index=fp8_tensor,
    )

    # Launch kernel
    if rmsnorm_out is not None:
        return tex.rmsnorm_fwd_fp8_noalloc(
            inp,
            weight,
            eps,
            fp8_scales["scale"],
            rmsnorm_out,
            fp8_scales["amax"],
            fp8_scales["scale_inv"],
            otype,
            sm_margin,
            zero_centered_gamma,
            **fp8_scales_offsets,
        )
    return tex.rmsnorm_fwd_fp8(
        inp,
        weight,
        eps,
        fp8_scales["scale"],
        fp8_scales["amax"],
        fp8_scales["scale_inv"],
        otype,
        sm_margin,
        zero_centered_gamma,
        **fp8_scales_offsets,
    )

@@ -166,25 +202,40 @@ def rmsnorm_fwd_fp8_inf(
    inp: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
    fp8_meta_tensor: Optional[tex.FP8TensorMeta],
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None],
    otype: tex.DType,
    sm_margin: int,
    zero_centered_gamma,
    scale: Optional[torch.Tensor] = None,
    amax: Optional[torch.Tensor] = None,
    scale_inv: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """RMSNorm with FP8 output.

    This version of rmsnorm_fwd_fp8 is specialized for inference, and returns
    only the normalized output.
    """

    # Get FP8 scaling factors
    fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales(
        scale=scale,
        amax=amax,
        scale_inv=scale_inv,
        fp8_meta=fp8_meta_tensor,
        fp8_meta_index=fp8_tensor,
        allow_multiple_offsets=False,
    )

    # Launch kernel
    ret = torch.ops.tex_ts.rmsnorm_fwd_fp8_inf_ts(
        inp,
        weight,
        eps,
        fp8_scales["scale"],
        fp8_scales["amax"],
        fp8_scales["scale_inv"],
        fp8_scales_offsets["scale_offset"],
        otype,
        sm_margin,
        zero_centered_gamma,
......
@@ -4,9 +4,12 @@
"""Python interface for transpose extensions"""
from typing import List, Optional, Tuple, Union

import torch

import transformer_engine_torch as tex

from ..constants import TE_DType
from ._common import canonicalize_fp8_scales, empty_tensor

__all__ = [

@@ -20,83 +23,115 @@ __all__ = [
def fp8_cast_transpose_fused(
    inp: torch.Tensor,
    fp8_meta_tensor: Optional[tex.FP8TensorMeta],
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None],
    otype: tex.DType,
    cast_out: Optional[torch.Tensor] = None,
    transpose_out: Optional[torch.Tensor] = None,
    scale: Optional[torch.Tensor] = None,
    amax: Optional[torch.Tensor] = None,
    scale_inv: Optional[torch.Tensor] = None,
    noop_flag: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Cast + Transpose with FP8 output"""

    # Allocate outputs if needed
    if transpose_out is None:
        transpose_out = torch.empty(inp.shape[1], inp.shape[0], device="cuda", dtype=torch.uint8)
    if cast_out is None:
        cast_out = torch.empty_like(inp, dtype=torch.uint8)

    # Get FP8 scaling factors
    fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales(
        scale=scale,
        amax=amax,
        scale_inv=scale_inv,
        fp8_meta=fp8_meta_tensor,
        fp8_meta_index=fp8_tensor,
    )

    # Construct no-op flag if needed
    if noop_flag is None:
        noop_flag = empty_tensor()

    # Launch kernel if needed
    if inp.nelement() > 0:
        tex.fused_cast_transpose_noop(
            inp,
            noop_flag,
            fp8_scales["scale"],
            fp8_scales["amax"],
            fp8_scales["scale_inv"],
            cast_out,
            transpose_out,
            otype,
            **fp8_scales_offsets,
        )

    return cast_out, transpose_out

def fp8_cast_transpose_bgrad_fused(
    inp: torch.Tensor,
    fp8_meta_tensor: Optional[tex.FP8TensorMeta],
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None],
    otype: tex.DType,
    scale: Optional[torch.Tensor] = None,
    amax: Optional[torch.Tensor] = None,
    scale_inv: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Cast + Transpose + BGRAD with FP8 output"""

    # Get FP8 scaling factors
    fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales(
        scale=scale,
        amax=amax,
        scale_inv=scale_inv,
        fp8_meta=fp8_meta_tensor,
        fp8_meta_index=fp8_tensor,
    )

    # Launch kernel
    return tex.fused_cast_transpose_bgrad(
        inp,
        fp8_scales["scale"],
        fp8_scales["amax"],
        fp8_scales["scale_inv"],
        otype,
        **fp8_scales_offsets,
    )

def fp8_transpose_bgrad_fused(
    inp: torch.Tensor,
    fp8_meta_tensor: Optional[tex.FP8TensorMeta],
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None],
    otype: tex.DType,
    grad_bias_type: torch.dtype,
    scale: Optional[torch.Tensor] = None,
    amax: Optional[torch.Tensor] = None,
    scale_inv: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Transpose + BGRAD with FP8 output"""

    # Get FP8 scaling factors
    fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales(
        scale=scale,
        amax=amax,
        scale_inv=scale_inv,
        fp8_meta=fp8_meta_tensor,
        fp8_meta_index=fp8_tensor,
    )

    # Launch kernel
    return tex.fused_fp8_transpose_bgrad(
        inp,
        fp8_scales["scale"],
        fp8_scales["amax"],
        fp8_scales["scale_inv"],
        otype,
        TE_DType[grad_bias_type],
        **fp8_scales_offsets,
    )

@@ -106,18 +141,30 @@ def fp8_cast_transpose_bgrad_dgelu_fused(
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
    scale: Optional[torch.Tensor] = None,
    amax: Optional[torch.Tensor] = None,
    scale_inv: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Cast + Transpose + BGRAD + DGELU with FP8 output"""

    # Get FP8 scaling factors
    fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales(
        scale=scale,
        amax=amax,
        scale_inv=scale_inv,
        fp8_meta=fp8_meta_tensor,
        fp8_meta_index=fp8_tensor,
    )

    # Launch kernel
    return tex.fused_cast_transpose_bgrad_dgelu(
        grad_output,
        gelu_input,
        fp8_scales["scale"],
        fp8_scales["amax"],
        fp8_scales["scale_inv"],
        otype,
        **fp8_scales_offsets,
    )
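A hypothetical usage sketch for fp8_cast_transpose_fused above with explicit scaling-factor buffers (CUDA device and compiled extension assumed; the amax shape is an assumption): the fused kernel produces the FP8 cast and its transpose while filling amax and scale_inv.

x = torch.randn(32, 64, device="cuda")
scale = torch.ones(1, device="cuda")
amax = torch.zeros(1, 1, device="cuda")      # assumed single-entry amax history
scale_inv = torch.empty(1, device="cuda")    # filled by the fused kernel

cast_out, transpose_out = fp8_cast_transpose_fused(
    x, None, None, tex.DType.kFloat8E4M3,
    scale=scale, amax=amax, scale_inv=scale_inv,
)
assert cast_out.shape == (32, 64) and transpose_out.shape == (64, 32)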
......
@@ -117,13 +117,6 @@ class _ToFloat8Func(torch.autograd.Function):
        scale_inv: Optional[torch.Tensor] = None,
    ) -> Float8Tensor:

        # Extract data from FP8 meta tensors if provided
        if fp8_meta is not None:
            fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(

@@ -138,9 +131,6 @@ class _ToFloat8Func(torch.autograd.Function):
                scale = fp8_meta[fp8_meta_key].scale[fp8_meta_index]
            if amax is None:
                amax = fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index]

        # Check input tensor
        tensor = tensor.contiguous().cuda().detach()

@@ -163,8 +153,9 @@ class _ToFloat8Func(torch.autograd.Function):
        # Check scale-inverse
        if scale_inv is None:
            scale_inv = torch.empty([1], dtype=torch.float32, device=tensor.device)
        else:
            scale_inv = scale_inv.to(device=tensor.device, dtype=torch.float32)

        # Check amax
        if amax is None:
@@ -737,19 +728,9 @@ class Float8Tensor(torch.Tensor):
            self._fp8_dtype,
            cast_out=data,
            transpose_out=transpose,
            scale_inv=self._scale_inv,
            noop_flag=noop_flag,
        )
        self._transpose_invalid = False

    @torch.no_grad()

@@ -853,7 +834,6 @@ class Float8Tensor(torch.Tensor):
            fp8_meta_index = dst._fp8_meta_index
            scale = dst._fp8_meta[fp8_meta_key].scale[fp8_meta_index]
            amax = dst._fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index]

        # Cast to FP8
        if not dst._data.is_contiguous():
......
@@ -52,6 +52,9 @@ def _apply_normalization(
    fwd_ln_sm_margin: int,
    zero_centered_gamma: bool,
    is_grad_enabled: bool,
    fp8_scale: Optional[torch.Tensor] = None,
    fp8_amax: Optional[torch.Tensor] = None,
    fp8_scale_inv: Optional[torch.Tensor] = None,
):
    normalization_func = _get_normalization_func(normalization, fp8_out, is_grad_enabled, True)

@@ -70,6 +73,9 @@ def _apply_normalization(
                fp8_dtype_forward,
                fwd_ln_sm_margin,
                zero_centered_gamma,
                scale=fp8_scale,
                amax=fp8_amax,
                scale_inv=fp8_scale_inv,
                **output_kwarg,
            )
        else:

@@ -82,6 +88,9 @@ def _apply_normalization(
                    fp8_dtype_forward,
                    fwd_ln_sm_margin,
                    zero_centered_gamma,
                    scale=fp8_scale,
                    amax=fp8_amax,
                    scale_inv=fp8_scale_inv,
                ),
                None,
                None,
......
@@ -46,6 +46,7 @@ from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ._common import _apply_normalization, _noop_cat
from ..float8_tensor import Float8Tensor
from ..export import is_in_onnx_export_mode

__all__ = ["LayerNormLinear"]

@@ -126,8 +127,13 @@ class _LayerNormLinear(torch.autograd.Function):
                inputmat, dtype=ln_out_dtype, memory_format=torch.contiguous_format
            )

        # Objects for FP8 cast
        fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
        ln_out_scale_inv = None
        if fp8:
            ln_out_scale_inv = torch.empty([1], dtype=torch.float32, device=inputmat.device)

        # Launch normalization kernel
        ln_out, mu, rsigma = _apply_normalization(
            inputmat,
            ln_out,

@@ -140,6 +146,7 @@ class _LayerNormLinear(torch.autograd.Function):
            fwd_ln_sm_margin,
            zero_centered_gamma,
            is_grad_enabled,
            fp8_scale_inv=ln_out_scale_inv,
        )

        # Column Parallel Linear

@@ -172,6 +179,7 @@ class _LayerNormLinear(torch.autograd.Function):
                    tex.FP8FwdTensors.GEMM1_INPUT,
                    fp8_dtype_forward,
                    out=ln_out_fp8,
                    scale_inv=ln_out_scale_inv,
                )
                ln_out = torch.empty_like(ln_out_fp8)
            else:

@@ -180,6 +188,7 @@ class _LayerNormLinear(torch.autograd.Function):
                    fp8_meta["scaling_fwd"],
                    tex.FP8FwdTensors.GEMM1_INPUT,
                    fp8_dtype_forward,
                    scale_inv=ln_out_scale_inv,
                )
            if ln_out_gathered:
                rank = torch.distributed.get_rank(tp_group)
@@ -199,6 +208,18 @@ class _LayerNormLinear(torch.autograd.Function):
            assert isinstance(weight_fp8, Float8Tensor)

            # Hack for ONNX export
            # Note: ONNX models are represented as a graph of tensor
            # operations, so the in-place scale-inv update doesn't fit
            # very well. We work around this by making it look like
            # the scale-inv tensor is initialized with a copy.
            # Note: ONNX export expects FP8 scales can be represented
            # with constant ops. However, copying into a buffer
            # involves an expand op for array broadcasting. We work
            # around this by filling the buffer instead.
            if is_in_onnx_export_mode():
                ln_out_scale_inv.fill_(ln_out_scale_inv.item())

            if fp8_meta["recipe"].fp8_mha:
                out_index, meta_tensor, output_te_dtype, output_dtype = (
                    tex.FP8FwdTensors.GEMM1_OUTPUT,
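As an aside on the workaround above, a small eager-mode illustration in plain PyTorch (independent of this module) of why a scalar fill exports more cleanly than a tensor copy:

scale_inv = torch.empty(1)
value = torch.tensor([0.5])
scale_inv.copy_(value)         # in a traced graph this goes through broadcasting (Expand) ops
scale_inv.fill_(value.item())  # a Python-scalar fill can be folded into a constant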
...@@ -219,8 +240,8 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -219,8 +240,8 @@ class _LayerNormLinear(torch.autograd.Function):
0, 0,
weight_fp8._fp8_dtype, weight_fp8._fp8_dtype,
ln_out_total, ln_out_total,
fp8_meta["scaling_fwd"].scale_inv, ln_out_scale_inv,
tex.FP8FwdTensors.GEMM1_INPUT, 0,
fp8_dtype_forward, fp8_dtype_forward,
output_dtype, output_dtype,
get_workspace(), get_workspace(),
...@@ -306,7 +327,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -306,7 +327,7 @@ class _LayerNormLinear(torch.autograd.Function):
weight_fp8, weight_fp8,
weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None, weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None,
ln_out if weight.requires_grad else None, ln_out if weight.requires_grad else None,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, ln_out_scale_inv,
) )
ctx.activation_dtype = activation_dtype ctx.activation_dtype = activation_dtype
...@@ -377,7 +398,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -377,7 +398,7 @@ class _LayerNormLinear(torch.autograd.Function):
weight_fp8, weight_fp8,
main_grad, main_grad,
ln_out, ln_out,
fwd_scale_inverses, ln_out_scale_inv,
) = ctx.saved_tensors ) = ctx.saved_tensors
# Gather intermediate/activation tensors if needed # Gather intermediate/activation tensors if needed
...@@ -570,8 +591,8 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -570,8 +591,8 @@ class _LayerNormLinear(torch.autograd.Function):
ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward) ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward)
wgrad, _ = tex.fp8_gemm( wgrad, _ = tex.fp8_gemm(
ln_out_total_t, ln_out_total_t,
fwd_scale_inverses, ln_out_scale_inv,
tex.FP8FwdTensors.GEMM1_INPUT, 0,
fp8_dtype_forward, fp8_dtype_forward,
( (
grad_output_t._data grad_output_t._data
...@@ -596,8 +617,8 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -596,8 +617,8 @@ class _LayerNormLinear(torch.autograd.Function):
else: else:
ln_out_total_c = torch.ops.tex_ts.cast_from_fp8_ts( ln_out_total_c = torch.ops.tex_ts.cast_from_fp8_ts(
ln_out_total, ln_out_total,
fwd_scale_inverses, ln_out_scale_inv,
tex.FP8FwdTensors.GEMM1_INPUT, 0,
fp8_dtype_forward, fp8_dtype_forward,
TE_DType[ctx.activation_dtype], TE_DType[ctx.activation_dtype],
) )
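In the backward pass the saved one-element `ln_out_scale_inv` replaces `fwd_scale_inverses` (previously a clone of the whole forward scale-inverse table), so both the wgrad GEMM and the `cast_from_fp8_ts` fallback index offset 0 of the dedicated buffer. A plain-PyTorch stand-in for the dequantize step this implies, with int8 data faking the FP8 payload purely for illustration:

```python
# Plain-PyTorch stand-in for the dequantize that cast_from_fp8_ts performs;
# the int8 tensor is a fake FP8 payload used only to show the arithmetic.
import torch

ln_out_total = torch.randint(-127, 127, (8, 16), dtype=torch.int8)  # stand-in FP8 data
ln_out_scale_inv = torch.tensor([1.0 / 64.0])                       # saved from forward

# offset 0 selects the single per-tensor scale-inverse in the one-element buffer
ln_out_total_c = ln_out_total.to(torch.float32) * ln_out_scale_inv[0]
```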
......
...@@ -48,6 +48,7 @@ from ..constants import GemmParallelModes, dist_group_type ...@@ -48,6 +48,7 @@ from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing from ..graph import is_graph_capturing
from ..float8_tensor import Float8Tensor from ..float8_tensor import Float8Tensor
from ..export import is_in_onnx_export_mode
__all__ = ["Linear"] __all__ = ["Linear"]
...@@ -103,10 +104,12 @@ class _Linear(torch.autograd.Function): ...@@ -103,10 +104,12 @@ class _Linear(torch.autograd.Function):
inputmat = cast_if_needed(inputmat, activation_dtype) inputmat = cast_if_needed(inputmat, activation_dtype)
inputmat_t = None inputmat_t = None
inputmat_no_fp8 = inputmat inputmat_no_fp8 = inputmat
inputmat_scale_inv = None
if fp8: if fp8:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
if isinstance(inputmat, Float8Tensor): if isinstance(inputmat, Float8Tensor):
inputmat_scale_inv = inputmat._scale_inv
if ( if (
not fp8_meta["recipe"].override_linear_precision.wgrad not fp8_meta["recipe"].override_linear_precision.wgrad
and is_grad_enabled and is_grad_enabled
...@@ -116,6 +119,7 @@ class _Linear(torch.autograd.Function): ...@@ -116,6 +119,7 @@ class _Linear(torch.autograd.Function):
# FP8 input for forward, FP8 input transpose for backward wgrad # FP8 input for forward, FP8 input transpose for backward wgrad
inputmat_t = inputmat.transpose_2d() inputmat_t = inputmat.transpose_2d()
else: else:
inputmat_scale_inv = torch.empty([1], dtype=torch.float32, device=inputmat.device)
if ( if (
not fp8_meta["recipe"].override_linear_precision.wgrad not fp8_meta["recipe"].override_linear_precision.wgrad
and is_grad_enabled and is_grad_enabled
...@@ -128,6 +132,7 @@ class _Linear(torch.autograd.Function): ...@@ -128,6 +132,7 @@ class _Linear(torch.autograd.Function):
fp8_meta["scaling_fwd"], fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT, tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward, fp8_dtype_forward,
scale_inv=inputmat_scale_inv,
) )
else: else:
# FP8 input for forward # FP8 input for forward
...@@ -136,8 +141,21 @@ class _Linear(torch.autograd.Function): ...@@ -136,8 +141,21 @@ class _Linear(torch.autograd.Function):
fp8_meta["scaling_fwd"], fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT, tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward, fp8_dtype_forward,
scale_inv=inputmat_scale_inv,
) )
# Hack for ONNX export
# Note: ONNX models are represented as a graph of tensor
# operations, so the in-place scale-inv update doesn't fit
# very well. We work around this by making it look like
# the scale-inv tensor is initialized with a copy.
# Note: ONNX export expects FP8 scales can be represented
# with constant ops. However, copying into a buffer
# involves an expand op for array broadcasting. We work
# around this by filling the buffer instead.
if is_in_onnx_export_mode():
inputmat_scale_inv.fill_(inputmat_scale_inv.item())
# Column Parallel Linear # Column Parallel Linear
if parallel_mode == "column" and sequence_parallel: if parallel_mode == "column" and sequence_parallel:
inputmat_total, _ = gather_along_first_dim(inputmat, tp_group) inputmat_total, _ = gather_along_first_dim(inputmat, tp_group)
...@@ -206,8 +224,8 @@ class _Linear(torch.autograd.Function): ...@@ -206,8 +224,8 @@ class _Linear(torch.autograd.Function):
if isinstance(inputmat_total, Float8Tensor) if isinstance(inputmat_total, Float8Tensor)
else inputmat_total else inputmat_total
), ),
fp8_meta["scaling_fwd"].scale_inv, inputmat_scale_inv,
tex.FP8FwdTensors.GEMM1_INPUT, 0,
fp8_dtype_forward, fp8_dtype_forward,
proj_out_pttype, proj_out_pttype,
get_workspace(), get_workspace(),
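The GEMM call now receives the dedicated `inputmat_scale_inv` buffer with index 0 instead of the shared `fp8_meta["scaling_fwd"].scale_inv` table indexed by `GEMM1_INPUT`. A hedged emulation of the (scale-inverse tensor, offset) convention this diff suggests; `emulated_fp8_gemm` is an assumption for illustration, not the library call, and int8 tensors again stand in for FP8 data:

```python
# Hedged emulation of the (scale_inv tensor, offset) convention suggested by
# the diff; emulated_fp8_gemm is illustrative, not tex.fp8_gemm.
import torch

def emulated_fp8_gemm(a_fp8, a_scale_inv, a_off, b_fp8, b_scale_inv, b_off):
    a = a_fp8.to(torch.float32) * a_scale_inv[a_off]
    b = b_fp8.to(torch.float32) * b_scale_inv[b_off]
    return a @ b

a = torch.randint(-127, 127, (4, 8), dtype=torch.int8)
b = torch.randint(-127, 127, (8, 3), dtype=torch.int8)
inputmat_scale_inv = torch.tensor([0.02])        # one-element buffer -> offset 0
weight_scale_inv = torch.tensor([0.01, 0.03])    # table indexed by an enum offset
out = emulated_fp8_gemm(a, inputmat_scale_inv, 0, b, weight_scale_inv, 1)
```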
...@@ -312,10 +330,10 @@ class _Linear(torch.autograd.Function): ...@@ -312,10 +330,10 @@ class _Linear(torch.autograd.Function):
ctx.save_for_backward( ctx.save_for_backward(
saved_inputmat, saved_inputmat,
saved_inputmat_t, saved_inputmat_t,
inputmat_scale_inv,
weight, weight,
weight_fp8, weight_fp8,
weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None, weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None,
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
) )
ctx.activation_dtype = activation_dtype ctx.activation_dtype = activation_dtype
...@@ -364,10 +382,10 @@ class _Linear(torch.autograd.Function): ...@@ -364,10 +382,10 @@ class _Linear(torch.autograd.Function):
( (
inputmat, inputmat,
inputmat_t, inputmat_t,
inputmat_scale_inv,
weight, weight,
weight_fp8, weight_fp8,
main_grad, main_grad,
fwd_scale_inverses,
) = ctx.saved_tensors ) = ctx.saved_tensors
# Gather intermediate/activation tensors if needed # Gather intermediate/activation tensors if needed
...@@ -520,8 +538,8 @@ class _Linear(torch.autograd.Function): ...@@ -520,8 +538,8 @@ class _Linear(torch.autograd.Function):
if isinstance(inputmat_t_total, Float8Tensor) if isinstance(inputmat_t_total, Float8Tensor)
else inputmat_t_total else inputmat_t_total
), ),
fwd_scale_inverses, inputmat_scale_inv,
tex.FP8FwdTensors.GEMM1_INPUT, 0,
fp8_dtype_forward, fp8_dtype_forward,
grad_output_t, grad_output_t,
ctx.fp8_meta["scaling_bwd"].scale_inv, ctx.fp8_meta["scaling_bwd"].scale_inv,
......
...@@ -74,7 +74,7 @@ def is_dtype_bf16(t): ...@@ -74,7 +74,7 @@ def is_dtype_bf16(t):
return t.type().scalarType() == "BFloat16" return t.type().scalarType() == "BFloat16"
def quantize(g, inputs, scale_inv, fp8_tensor): def quantize(g, inputs, scale, fp8_tensor):
"""Helper Function for Quantization""" """Helper Function for Quantization"""
output_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inputs) output_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inputs)
...@@ -83,7 +83,7 @@ def quantize(g, inputs, scale_inv, fp8_tensor): ...@@ -83,7 +83,7 @@ def quantize(g, inputs, scale_inv, fp8_tensor):
if not is_dtype_fp32(inputs): if not is_dtype_fp32(inputs):
inputs = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT) inputs = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT)
scale = g.op("Constant", value_t=torch.tensor(scale_inv[fp8_tensor])) scale = g.op("Constant", value_t=torch.tensor(1 / scale[fp8_tensor]))
q_op = g.op(make_op_name("TRT_FP8QuantizeLinear"), inputs, scale).setType( q_op = g.op(make_op_name("TRT_FP8QuantizeLinear"), inputs, scale).setType(
inputs.type().with_dtype(torch.uint8).with_sizes(output_shape) inputs.type().with_dtype(torch.uint8).with_sizes(output_shape)
) )
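The symbolic `quantize` helper now takes the quantization scale and bakes `1 / scale[fp8_tensor]` into the Constant fed to `TRT_FP8QuantizeLinear`, apparently because the op expects the same value that `scale_inv` previously supplied. The underlying scale/scale-inverse relationship, shown as a small sketch (the op's exact convention is an assumption read off this diff):

```python
# Sketch of the scale / scale-inverse relationship assumed above: quantize
# multiplies by the scaling factor, dequantize multiplies by its reciprocal.
import torch

scale = torch.tensor(64.0)            # quantization scaling factor (amax-derived)
x = torch.randn(4)

x_q = x * scale                       # quantize (before rounding to FP8)
x_dq = x_q * (1.0 / scale)            # dequantize with scale_inv = 1 / scale

assert torch.allclose(x, x_dq)
```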
...@@ -124,18 +124,18 @@ def compute_in_fp32(g, inp, subgraph, *args, **kwargs): ...@@ -124,18 +124,18 @@ def compute_in_fp32(g, inp, subgraph, *args, **kwargs):
return sg_out return sg_out
@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i") @symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i")
def onnx_cast_to_fp8(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): def onnx_cast_to_fp8(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
"""ONNX graph for cast_to_fp8""" """ONNX graph for cast_to_fp8"""
# pylint: disable=unused-argument # pylint: disable=unused-argument
return quantize(g, inputs, scale_inv, fp8_tensor) return quantize(g, inputs, scale, fp8_tensor)
@symbolic_helper.parse_args("v", "v", "v", "v", "fs", "i", "i") @symbolic_helper.parse_args("v", "fs", "v", "v", "v", "i", "i")
def onnx_cast_to_fp8_noalloc(g, inputs, scale, output, amax, scale_inv, fp8_tensor, otype): def onnx_cast_to_fp8_noalloc(g, inputs, scale, output, amax, scale_inv, fp8_tensor, otype):
"""ONNX graph for cast_to_fp8_noalloc""" """ONNX graph for cast_to_fp8_noalloc"""
# pylint: disable=unused-argument # pylint: disable=unused-argument
return quantize(g, inputs, scale_inv, fp8_tensor) return quantize(g, inputs, scale, fp8_tensor)
@symbolic_helper.parse_args("v", "fs", "i", "i", "i") @symbolic_helper.parse_args("v", "fs", "i", "i", "i")
...@@ -145,25 +145,25 @@ def onnx_cast_from_fp8(g, inputs, scale_inv, fp8_tensor, itype, otype): ...@@ -145,25 +145,25 @@ def onnx_cast_from_fp8(g, inputs, scale_inv, fp8_tensor, itype, otype):
return dequantize(g, inputs, scale_inv, fp8_tensor, otype) return dequantize(g, inputs, scale_inv, fp8_tensor, otype)
@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i") @symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i")
def onnx_fp8_gelu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): def onnx_fp8_gelu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
"""ONNX graph for fp8_gelu""" """ONNX graph for fp8_gelu"""
# pylint: disable=unused-argument # pylint: disable=unused-argument
# TE computes GELU using float32 precision so wrap the GELU subgraph with # TE computes GELU using float32 precision so wrap the GELU subgraph with
# conversion to/from float32. # conversion to/from float32.
gelu = compute_in_fp32(g, inputs, torch.onnx.symbolic_opset9.gelu, "tanh") gelu = compute_in_fp32(g, inputs, torch.onnx.symbolic_opset9.gelu, "tanh")
if scale_inv: if scale:
gelu = quantize(g, gelu, scale_inv, fp8_tensor) gelu = quantize(g, gelu, scale, fp8_tensor)
return gelu return gelu
@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i") @symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i")
def onnx_fp8_relu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): def onnx_fp8_relu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
"""ONNX graph for fp8_relu""" """ONNX graph for fp8_relu"""
# pylint: disable=unused-argument # pylint: disable=unused-argument
relu = compute_in_fp32(g, inputs, torch.onnx.symbolic_opset9.relu) relu = compute_in_fp32(g, inputs, torch.onnx.symbolic_opset9.relu)
if scale_inv: if scale:
relu = quantize(g, relu, scale_inv, fp8_tensor) relu = quantize(g, relu, scale, fp8_tensor)
return relu return relu
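The `@symbolic_helper.parse_args` strings change because the argument parsed as a list of floats (`"fs"`) is now the scale in the second position rather than scale_inv further down the signature: in PyTorch's symbolic helper, `"v"` leaves an argument as a graph value, `"fs"` parses it into Python floats, `"i"` into an int, and `"b"` into a bool. A minimal hedged example of the decorator's use; `my_symbolic` is a toy function, not one of TE's registered symbolics:

```python
# Toy symbolic function showing parse_args codes; my_symbolic is illustrative only.
from torch.onnx import symbolic_helper

@symbolic_helper.parse_args("v", "fs", "i")
def my_symbolic(g, inputs, scale, fp8_tensor):
    # 'inputs' stays a graph value, 'scale' arrives as a list of Python floats,
    # and 'fp8_tensor' as a Python int, so scale[fp8_tensor] is plain arithmetic.
    return inputs
```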
...@@ -178,13 +178,13 @@ def onnx_swiglu(g: jit_utils.GraphContext, inp, dim): ...@@ -178,13 +178,13 @@ def onnx_swiglu(g: jit_utils.GraphContext, inp, dim):
return g.op("Mul", g.op("Sigmoid", first), second) return g.op("Mul", g.op("Sigmoid", first), second)
@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i") @symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i")
def onnx_fp8_swiglu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): def onnx_fp8_swiglu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
"""ONNX graph for fp8_swiglu""" """ONNX graph for fp8_swiglu"""
# pylint: disable=unused-argument # pylint: disable=unused-argument
swiglu = compute_in_fp32(g, inputs, onnx_swiglu, 1) swiglu = compute_in_fp32(g, inputs, onnx_swiglu, 1)
if scale_inv: if scale:
swiglu = quantize(g, swiglu, scale_inv, fp8_tensor) swiglu = quantize(g, swiglu, scale, fp8_tensor)
return swiglu return swiglu
...@@ -199,13 +199,13 @@ def onnx_reglu(g: jit_utils.GraphContext, inp, dim): ...@@ -199,13 +199,13 @@ def onnx_reglu(g: jit_utils.GraphContext, inp, dim):
return g.op("Mul", g.op("Relu", first), second) return g.op("Mul", g.op("Relu", first), second)
@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i") @symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i")
def onnx_fp8_reglu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): def onnx_fp8_reglu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
"""ONNX graph for fp8_reglu""" """ONNX graph for fp8_reglu"""
# pylint: disable=unused-argument # pylint: disable=unused-argument
reglu = compute_in_fp32(g, inputs, onnx_reglu, 1) reglu = compute_in_fp32(g, inputs, onnx_reglu, 1)
if scale_inv: if scale:
reglu = quantize(g, reglu, scale_inv, fp8_tensor) reglu = quantize(g, reglu, scale, fp8_tensor)
return reglu return reglu
...@@ -221,13 +221,13 @@ def onnx_geglu(g: jit_utils.GraphContext, inp, dim): ...@@ -221,13 +221,13 @@ def onnx_geglu(g: jit_utils.GraphContext, inp, dim):
return g.op("Mul", first_gelu, second) return g.op("Mul", first_gelu, second)
@symbolic_helper.parse_args("v", "v", "v", "fs", "i", "i") @symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i")
def onnx_fp8_geglu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): def onnx_fp8_geglu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
"""ONNX graph for fp8_geglu""" """ONNX graph for fp8_geglu"""
# pylint: disable=unused-argument # pylint: disable=unused-argument
geglu = compute_in_fp32(g, inputs, onnx_geglu, 1) geglu = compute_in_fp32(g, inputs, onnx_geglu, 1)
if scale_inv: if scale:
geglu = quantize(g, geglu, scale_inv, fp8_tensor) geglu = quantize(g, geglu, scale, fp8_tensor)
return geglu return geglu
...@@ -245,7 +245,7 @@ def onnx_fp8_geglu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): ...@@ -245,7 +245,7 @@ def onnx_fp8_geglu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
"v", "v",
"fs", "fs",
"i", "i",
"fs", "v",
"v", "v",
"i", "i",
"v", "v",
...@@ -330,7 +330,7 @@ def _ones_like(g, inp, dtype): ...@@ -330,7 +330,7 @@ def _ones_like(g, inp, dtype):
return one return one
@symbolic_helper.parse_args("v", "v", "v", "f", "v", "v", "fs", "i", "i", "i", "b") @symbolic_helper.parse_args("v", "v", "v", "f", "fs", "v", "v", "i", "i", "i", "b")
def onnx_layernorm_fwd_fp8( def onnx_layernorm_fwd_fp8(
g, g,
inputs, inputs,
...@@ -355,7 +355,7 @@ def onnx_layernorm_fwd_fp8( ...@@ -355,7 +355,7 @@ def onnx_layernorm_fwd_fp8(
bias = g.op("Cast", bias, to_i=inp_dtype) bias = g.op("Cast", bias, to_i=inp_dtype)
ln = onnx_layernorm_fwd(g, inputs, weight, bias, eps, sm_margin, zero_centered_gamma) ln = onnx_layernorm_fwd(g, inputs, weight, bias, eps, sm_margin, zero_centered_gamma)
fp8_ln = quantize(g, ln, scale_inv, fp8_tensor) fp8_ln = quantize(g, ln, scale, fp8_tensor)
return fp8_ln return fp8_ln
...@@ -391,7 +391,7 @@ def onnx_layernorm_fwd(g, inputs, weight, bias, eps, sm_margin, zero_centered_ga ...@@ -391,7 +391,7 @@ def onnx_layernorm_fwd(g, inputs, weight, bias, eps, sm_margin, zero_centered_ga
return ln return ln
@symbolic_helper.parse_args("v", "v", "f", "v", "v", "fs", "i", "i", "i", "b") @symbolic_helper.parse_args("v", "v", "f", "fs", "v", "v", "i", "i", "i", "b")
def onnx_rmsnorm_fwd_fp8( def onnx_rmsnorm_fwd_fp8(
g, g,
inputs, inputs,
...@@ -413,7 +413,7 @@ def onnx_rmsnorm_fwd_fp8( ...@@ -413,7 +413,7 @@ def onnx_rmsnorm_fwd_fp8(
weight = g.op("Cast", weight, to_i=inp_dtype) weight = g.op("Cast", weight, to_i=inp_dtype)
ln = onnx_rmsnorm_fwd(g, inputs, weight, eps, sm_margin, zero_centered_gamma) ln = onnx_rmsnorm_fwd(g, inputs, weight, eps, sm_margin, zero_centered_gamma)
fp8_ln = quantize(g, ln, scale_inv, fp8_tensor) fp8_ln = quantize(g, ln, scale, fp8_tensor)
return fp8_ln return fp8_ln
......