Unverified Commit 207b231e authored by Tian Zheng's avatar Tian Zheng Committed by GitHub
Browse files

First step of PaddlePaddle integration (#249)



* First step of PaddlePaddle integration
- Add build option for paddle
- Add basic test framework
- Add 3 basic operators: cast_from_fp8, cast_to_fp8, gemm
Signed-off-by: Tian Zheng <tizheng@nvidia.com>
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Fix review comments
Signed-off-by: default avatarTian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Support paddle build
Signed-off-by: default avatarTian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Add paddle build support for new building framework
Signed-off-by: default avatarTian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Fix review comments
Signed-off-by: default avatarTian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Clean up build process for Paddle stub file
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Minor fixes
Signed-off-by: default avatarTian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Fix pylint "wrong-import-order" warning
Signed-off-by: default avatarTian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Fix review comments
Signed-off-by: default avatarTian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

* Skip BF16 GEMM tests for unsupported arch
Signed-off-by: default avatarTian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>

---------
Signed-off-by: Tian Zheng <tizheng@nvidia.com>
Signed-off-by: Tian Zheng (Engrg-Hardware 1) <tizheng@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 48b31ca9
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# Stop searching for additional config files.
set noparent
# Limit line length.
linelength=100
# Ignore the following errors.
filter=-build/include_subdir
filter=-build/namespaces
filter=-readability/todo
filter=-build/header_guard
filter=-build/include
filter=-build/c++11
[MASTER]
extension-pkg-whitelist=transformer_engine_paddle
disable=too-many-locals,
invalid-name,
too-many-arguments,
abstract-method,
arguments-differ,
too-many-instance-attributes,
unsubscriptable-object,
import-outside-toplevel,
too-many-statements,
import-error,
too-many-lines,
use-maxsplit-arg,
protected-access,
pointless-string-statement,
cyclic-import,
duplicate-code,
no-member,
attribute-defined-outside-init,
global-statement,
too-many-branches,
global-variable-not-assigned,
redefined-argument-from-local
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# Run the C++ (cpplint) and Python (pylint) linters over the Transformer
# Engine sources. Set PYTHON_ONLY to skip the C++ checks, or CPP_ONLY to
# skip the Python checks.
set -e

# Source tree location; override by exporting TE_PATH.
: "${TE_PATH:=/opt/transformerengine}"

# Pin linter versions so CI results are reproducible.
pip install cpplint==1.6.0 pylint==2.13.5

if [ -z "${PYTHON_ONLY}" ]
then
    # Quote all path expansions so a TE_PATH containing spaces does not
    # word-split into multiple arguments.
    cp "$TE_PATH/qa/L0_paddle_lint/CPPLINT.cfg" "$TE_PATH"
    cd "$TE_PATH"
    echo "Checking common API headers"
    cpplint --root transformer_engine/common/include --recursive transformer_engine/common/include
    echo "Checking C++ files"
    cpplint --recursive --exclude=transformer_engine/common/include transformer_engine/common
    cpplint --recursive transformer_engine/paddle
fi

if [ -z "${CPP_ONLY}" ]
then
    cp "$TE_PATH/qa/L0_paddle_lint/pylintrc" "$TE_PATH"
    cd "$TE_PATH"
    echo "Checking Python files"
    pylint --recursive=y transformer_engine/common transformer_engine/paddle
fi
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# Run the Paddle unit tests. TE_PATH may be overridden by the environment.
set -xe

# Quote the parameter expansion (matching the companion lint script) so a
# TE_PATH containing spaces or glob characters is handled safely.
: "${TE_PATH:=/opt/transformerengine}"
pytest -Wignore -v "$TE_PATH/tests/paddle"
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# #
# See LICENSE for license information. # See LICENSE for license information.
import ctypes
from functools import lru_cache from functools import lru_cache
import os import os
from pathlib import Path from pathlib import Path
...@@ -202,7 +203,7 @@ def with_userbuffers() -> bool: ...@@ -202,7 +203,7 @@ def with_userbuffers() -> bool:
def frameworks() -> List[str]: def frameworks() -> List[str]:
"""DL frameworks to build support for""" """DL frameworks to build support for"""
_frameworks: List[str] = [] _frameworks: List[str] = []
supported_frameworks = ["pytorch", "jax", "tensorflow"] supported_frameworks = ["pytorch", "jax", "tensorflow", "paddle"]
# Check environment variable # Check environment variable
if os.getenv("NVTE_FRAMEWORK"): if os.getenv("NVTE_FRAMEWORK"):
...@@ -234,6 +235,12 @@ def frameworks() -> List[str]: ...@@ -234,6 +235,12 @@ def frameworks() -> List[str]:
pass pass
else: else:
_frameworks.append("tensorflow") _frameworks.append("tensorflow")
try:
import paddle
except ImportError:
pass
else:
_frameworks.append("paddle")
# Special framework names # Special framework names
if "all" in _frameworks: if "all" in _frameworks:
...@@ -295,6 +302,9 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: ...@@ -295,6 +302,9 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
add_unique(setup_reqs, "pybind11") add_unique(setup_reqs, "pybind11")
add_unique(install_reqs, "tensorflow") add_unique(install_reqs, "tensorflow")
add_unique(test_reqs, ["keras", "tensorflow_datasets"]) add_unique(test_reqs, ["keras", "tensorflow_datasets"])
if "paddle" in frameworks():
add_unique(install_reqs, "paddlepaddle-gpu")
add_unique(test_reqs, "numpy")
return setup_reqs, install_reqs, test_reqs return setup_reqs, install_reqs, test_reqs
...@@ -359,6 +369,8 @@ class CMakeExtension(setuptools.Extension): ...@@ -359,6 +369,8 @@ class CMakeExtension(setuptools.Extension):
# PyTorch extension modules require special handling # PyTorch extension modules require special handling
if "pytorch" in frameworks(): if "pytorch" in frameworks():
from torch.utils.cpp_extension import BuildExtension from torch.utils.cpp_extension import BuildExtension
elif "paddle" in frameworks():
from paddle.utils.cpp_extension import BuildExtension
else: else:
from setuptools.command.build_ext import build_ext as BuildExtension from setuptools.command.build_ext import build_ext as BuildExtension
...@@ -384,6 +396,15 @@ class CMakeBuildExtension(BuildExtension): ...@@ -384,6 +396,15 @@ class CMakeBuildExtension(BuildExtension):
install_dir=install_dir, install_dir=install_dir,
) )
# Paddle requires linker search path for libtransformer_engine.so
paddle_ext = None
if "paddle" in frameworks():
for ext in self.extensions:
if "paddle" in ext.name:
ext.library_dirs.append(self.build_lib)
paddle_ext = ext
break
# Build non-CMake extensions as usual # Build non-CMake extensions as usual
all_extensions = self.extensions all_extensions = self.extensions
self.extensions = [ self.extensions = [
...@@ -393,6 +414,34 @@ class CMakeBuildExtension(BuildExtension): ...@@ -393,6 +414,34 @@ class CMakeBuildExtension(BuildExtension):
super().run() super().run()
self.extensions = all_extensions self.extensions = all_extensions
# Manually write stub file for Paddle extension
if paddle_ext is not None:
# Load libtransformer_engine.so to avoid linker errors
for path in Path(self.build_lib).iterdir():
if path.name.startswith("libtransformer_engine."):
ctypes.CDLL(str(path), mode=ctypes.RTLD_GLOBAL)
# Figure out stub file path
module_name = paddle_ext.name
assert module_name.endswith("_pd_"), \
"Expected Paddle extension module to end with '_pd_'"
stub_name = module_name[:-4] # remove '_pd_'
stub_path = os.path.join(self.build_lib, stub_name + ".py")
# Figure out library name
# Note: This library doesn't actually exist. Paddle
# internally reinserts the '_pd_' suffix.
so_path = self.get_ext_fullpath(module_name)
_, so_ext = os.path.splitext(so_path)
lib_name = stub_name + so_ext
# Write stub file
print(f"Writing Paddle stub for {lib_name} into file {stub_path}")
from paddle.utils.cpp_extension.extension_utils import custom_write_stub
custom_write_stub(lib_name, stub_path)
def setup_common_extension() -> CMakeExtension: def setup_common_extension() -> CMakeExtension:
"""Setup CMake extension for common library """Setup CMake extension for common library
...@@ -484,6 +533,69 @@ def setup_pytorch_extension() -> setuptools.Extension: ...@@ -484,6 +533,69 @@ def setup_pytorch_extension() -> setuptools.Extension:
) )
def setup_paddle_extension() -> setuptools.Extension:
    """Setup CUDA extension for Paddle support

    Builds a paddle.utils.cpp_extension.CUDAExtension from the csrc sources,
    linking against libtransformer_engine. The module name carries a '_pd_'
    suffix, which Paddle's build machinery expects (see the stub-file logic
    in the build command).
    """
    # Source files
    src_dir = root_path / "transformer_engine" / "paddle" / "csrc"
    sources = [
        src_dir / "extensions.cu",
        src_dir / "common.cpp",
        src_dir / "custom_ops.cu",
    ]

    # Header files
    include_dirs = [
        root_path / "transformer_engine" / "common" / "include",
        root_path / "transformer_engine" / "paddle" / "csrc",
    ]

    # Compiler flags
    cxx_flags = ["-O3"]
    nvcc_flags = [
        "-O3",
        # Baseline architecture: Volta (SM70).
        "-gencode",
        "arch=compute_70,code=sm_70",
        # Re-enable the half/bfloat16 operator overloads that the extension
        # build disables by default.
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT16_OPERATORS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT162_OPERATORS__",
        "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        "--use_fast_math",
    ]

    # Version-dependent CUDA options
    try:
        version = cuda_version()
    except FileNotFoundError:
        # nvcc not found; fall back to the baseline flags above.
        print("Could not determine CUDA Toolkit version")
    else:
        if version >= (11, 2):
            # Parallel compilation within nvcc.
            nvcc_flags.extend(["--threads", "4"])
        if version >= (11, 0):
            # Ampere (SM80) support.
            nvcc_flags.extend(["-gencode", "arch=compute_80,code=sm_80"])
        if version >= (11, 8):
            # Hopper (SM90) support.
            nvcc_flags.extend(["-gencode", "arch=compute_90,code=sm_90"])

    # Construct Paddle CUDA extension
    sources = [str(path) for path in sources]
    include_dirs = [str(path) for path in include_dirs]
    from paddle.utils.cpp_extension import CUDAExtension
    ext = CUDAExtension(
        sources=sources,
        include_dirs=include_dirs,
        libraries=["transformer_engine"],
        extra_compile_args={
            "cxx": cxx_flags,
            "nvcc": nvcc_flags,
        },
    )
    # Paddle internally appends/strips the '_pd_' suffix; the stub writer in
    # the build command relies on this exact name.
    ext.name = "transformer_engine_paddle_pd_"
    return ext
def main(): def main():
# Submodules to install # Submodules to install
...@@ -499,6 +611,9 @@ def main(): ...@@ -499,6 +611,9 @@ def main():
if "pytorch" in frameworks(): if "pytorch" in frameworks():
ext_modules.append(setup_pytorch_extension()) ext_modules.append(setup_pytorch_extension())
if "paddle" in frameworks():
ext_modules.append(setup_paddle_extension())
# Configure package # Configure package
setuptools.setup( setuptools.setup(
name="transformer_engine", name="transformer_engine",
......
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Test basic installation of Paddle extensions"""
def test_import():
    """
    Test if Paddle extension can be imported normally
    """
    # Importing the package triggers loading of the compiled Paddle custom
    # ops, so a successful import verifies the extension built and links.
    import transformer_engine.paddle    # pylint: disable=unused-import
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Test TE operators"""
import pytest
import paddle
from utils import assert_allclose, create_fp8_meta
import transformer_engine # pylint: disable=unused-import
import transformer_engine_paddle as tex # pylint: disable=wrong-import-order
from transformer_engine.paddle.cpp_extensions import cast_to_fp8, cast_from_fp8, gemm, fp8_gemm
from transformer_engine.paddle.fp8 import is_fp8_available
paddle.seed(10)
GEMM_CASES = [(256, 256, 512), (32, 32, 32), (16384, 1024, 2816), (16384, 2816, 1024),
(16384, 1024, 1024)]
is_fp8_supported, reason = is_fp8_available()
def test_quantize_dequantize():
    """
    Test cast_to_fp8 and cast_from_fp8
    """
    a = paddle.rand(shape=(32, 32), dtype='float32')
    # Init fp8_meta
    fp8_meta = create_fp8_meta(num_fp8_tensors=3, amax_history_len=10)
    # Round-trip through both FP8 formats (E4M3 and E5M2).
    for fp8_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
        a_fp8 = cast_to_fp8(a, fp8_meta, tex.FP8FwdTensors.GEMM1_OUTPUT, otype=fp8_dtype)
        b = cast_from_fp8(a_fp8,
                          fp8_meta,
                          tex.FP8FwdTensors.GEMM1_OUTPUT,
                          itype=fp8_dtype,
                          otype=tex.DType.kFloat32)
        # FP8 has very limited precision, so the round trip is only
        # approximate; loose tolerances reflect that.
        assert_allclose(a, b, rtol=5e-2, atol=5e-2)
class TestGemm:
    """
    Tests for gemm(cuBLASLt) operator
    """

    @staticmethod
    @pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0),
                        reason="BF16 GEMM requires Ampere+ GPU")
    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    def test_bf16(m, n, k):
        """
        Test "TN" BF16 GEMM
        """
        a = paddle.rand(shape=(m, k), dtype='bfloat16')
        b = paddle.rand(shape=(n, k), dtype='bfloat16')
        # 32 MiB cuBLASLt workspace buffer.
        workspace = paddle.zeros(shape=[33_554_432], dtype='uint8')
        ref_out = paddle.matmul(a, b.T)
        # CublasLt inside tex.te_gemm assumes inputs are column major.
        # Mathematically, A@B=C is equivalent to B^T@A^T=C^T, where X^T is the
        # transpose of X.
        # Here we perform "TN" GEMM in column major, i.e., b@a^T = C^T,
        # which is equivalent to a@b^T = C in row major.
        actual_out, _, _ = gemm(b, a, paddle.bfloat16, workspace, False, None, False, False, "TN",
                                None, None, False)
        assert_allclose(actual_out, ref_out)

    @staticmethod
    @pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0),
                        reason="BF16 GEMM requires Ampere+ GPU")
    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    def test_bf16_inplace(m, n, k):
        """
        Test "TN" BF16 GEMM, with accumulate=True
        """
        min_val = -16
        max_val = 16
        a = paddle.rand(shape=(m, k), dtype='bfloat16')
        b = paddle.rand(shape=(n, k), dtype='bfloat16')
        c = paddle.cast(paddle.randint(min_val, max_val, shape=(m, n)), 'bfloat16')
        workspace = paddle.zeros(shape=[33_554_432], dtype='uint8')
        ref_out = c + paddle.matmul(a, b.T)

        # With accumulate=True the GEMM adds into the existing `out` buffer,
        # so the result is read back from `actual_out` rather than returned.
        actual_out = paddle.clone(c)
        _, _, _ = gemm(b, a, paddle.bfloat16, workspace, False, None, False, True, "TN", actual_out,
                       None, False)

        assert_allclose(actual_out, ref_out, rtol=5e-2, atol=5e-2)

    @staticmethod
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize('m,n,k', GEMM_CASES)
    def test_fp8_randint(m, n, k):
        """
        Test "TN" FP8 GEMM
        """
        min_val = -8
        max_val = 8
        fp8_dtype = tex.DType.kFloat8E4M3
        out_dtype = paddle.float32
        fp8_meta = create_fp8_meta(num_fp8_tensors=3, amax_history_len=10)

        # Small random integers are exactly representable in E4M3, so the
        # FP8 GEMM result can be compared against the float32 reference.
        a = paddle.cast(paddle.randint(min_val, max_val, shape=(m, k)), 'float32')
        a_casted = cast_to_fp8(a, fp8_meta, tex.FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype)
        b = paddle.cast(paddle.randint(min_val, max_val, shape=(n, k)), 'float32')
        b_casted = cast_to_fp8(b, fp8_meta, tex.FP8FwdTensors.GEMM1_WEIGHT, otype=fp8_dtype)
        workspace = paddle.zeros(shape=[33_554_432], dtype='uint8')

        ref_out = paddle.matmul(a, b.T)
        actual_out = fp8_gemm(b_casted, fp8_meta.scale_inv, tex.FP8FwdTensors.GEMM1_WEIGHT,
                              fp8_dtype, a_casted, fp8_meta.scale_inv,
                              tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype, out_dtype, workspace)

        assert_allclose(actual_out, ref_out)
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Utils for testing"""
import paddle
import numpy as np
import transformer_engine # pylint: disable=unused-import
import transformer_engine_paddle as tex # pylint: disable=wrong-import-order
def create_fp8_meta(num_fp8_tensors, amax_history_len):
    """
    Create and initialize FP8TensorMeta

    Scales start at 1 (identity scaling) and the amax history starts zeroed.
    `num_fp8_tensors` is the number of per-tensor scaling slots;
    `amax_history_len` is the rolling-history depth per slot.
    """
    fp8_meta = tex.FP8TensorMeta()
    fp8_meta.scale = paddle.ones(num_fp8_tensors, dtype='float32')
    fp8_meta.scale_inv = paddle.ones(num_fp8_tensors, dtype='float32')
    fp8_meta.amax_history = paddle.zeros((amax_history_len, num_fp8_tensors), dtype='float32')
    return fp8_meta
def assert_allclose(actual,
                    desired,
                    rtol=1e-05,
                    atol=1e-08,
                    equal_nan=True,
                    err_msg='',
                    verbose=True):
    """Compare two input paddle tensors"""

    def _as_numpy(value):
        # Paddle tensors go through float32 first so half/bfloat16 values
        # can be compared by numpy; anything else passes through unchanged.
        if isinstance(value, paddle.Tensor):
            return paddle.cast(value, 'float32').numpy()
        return value

    np.testing.assert_allclose(_as_numpy(actual), _as_numpy(desired), rtol, atol,
                               equal_nan, err_msg, verbose)
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Transformer Engine bindings for Paddle"""
from .cpp_extensions import gemm, fp8_gemm, cast_to_fp8, cast_from_fp8
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Constants"""
import paddle
import transformer_engine_paddle as tex
"""
Map from paddle dtype to TE dtype
"""
TE_DType = {
paddle.uint8: tex.DType.kByte,
paddle.int32: tex.DType.kInt32,
paddle.float32: tex.DType.kFloat32,
paddle.float16: tex.DType.kFloat16,
paddle.bfloat16: tex.DType.kBFloat16,
}
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""TE FP8 extensions and GEMMs"""
from typing import Optional, Tuple, Union
import paddle
import transformer_engine_paddle as tex
from .constants import TE_DType
def gemm(
    A: paddle.Tensor,
    B: paddle.Tensor,
    dtype: paddle.dtype,
    workspace: paddle.Tensor,
    gelu: bool = False,
    gelu_input: Optional[paddle.Tensor] = None,
    grad: bool = False,
    accumulate: bool = False,
    layout: str = "TN",
    out: Optional[paddle.Tensor] = None,
    bias: Optional[paddle.Tensor] = None,
    use_bias: bool = False,
) -> Tuple[Union[paddle.Tensor, None], ...]:
    """Non FP8 GEMM.

    Thin wrapper around tex.te_gemm (cuBLASLt). `layout` gives the transpose
    flags for A and B respectively ("T" = transposed). Returns a
    (out, grad_bias, gelu_input) triple; `out` is None when the caller
    supplied its own output buffer.
    """
    assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported."
    transa = layout[0] == "T"
    transb = layout[1] == "T"

    return_output = False
    if out is None:
        # Allocate the output with the shape implied by the (possibly
        # transposed) operands.
        out = paddle.empty(
            shape=[
                B.shape[1] if transb else B.shape[0],
                A.shape[0] if transa else A.shape[1],
            ],
            dtype=dtype,
        )
        return_output = True

    if gelu and not grad:
        # Forward pass with fused GELU: te_gemm fills this buffer with the
        # pre-GELU activations (needed later for the backward pass).
        gelu_input = paddle.empty_like(out, dtype=dtype)
    elif not gelu:
        gelu_input = None

    if grad and use_bias:
        # Backward pass: te_gemm writes the bias gradient into this buffer.
        grad_bias = paddle.empty(shape=[B.shape[1]], dtype=out.dtype)
    else:
        grad_bias = None

    bias = bias if use_bias else None

    assert A.dtype == dtype and B.dtype == dtype, \
        f'Expected dtype={dtype}, but found A.dtype={A.dtype} and B.dtype={B.dtype}'
    input_dtype = TE_DType[dtype]
    output_dtype = TE_DType[out.dtype]
    if use_bias:
        bias_dtype = TE_DType[grad_bias.dtype] if grad else TE_DType[bias.dtype]
    else:
        bias_dtype = output_dtype

    # Positional argument order must match the te_gemm custom-op signature.
    tex.te_gemm(
        A,
        None,
        B,
        None,
        grad_bias if grad else bias,
        out,
        None,    # out_scale
        None,    # out_amax
        gelu_input,
        workspace,
        0,    # A_index
        0,    # B_index
        0,    # D_index
        int(input_dtype),
        int(input_dtype),
        int(output_dtype),
        int(bias_dtype),
        transa,
        transb,
        grad,
        workspace.shape[0],
        accumulate,
        False,    # use_split_accumulator
        0,    # math_sm_count
    )

    if return_output:
        return out, grad_bias, gelu_input
    return None, grad_bias, gelu_input
def fp8_gemm(
    A: paddle.Tensor,
    A_scale_inv: paddle.Tensor,
    A_fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    A_dtype: tex.DType,
    B: paddle.Tensor,
    B_scale_inv: paddle.Tensor,
    B_fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    B_dtype: tex.DType,
    out_dtype: paddle.dtype,
    workspace: paddle.Tensor,
    gelu: bool = False,
    accumulate: bool = False,
    out: Optional[paddle.Tensor] = None,
    out_index=None,
    fp8_meta_tensor: tex.FP8TensorMeta = None,
    bias: Optional[paddle.Tensor] = None,
    use_bias: bool = False,
    use_split_accumulator: bool = False,
    D_dtype: Optional[tex.DType] = None,
) -> paddle.Tensor:
    """TN layout GEMM with fp8 inputs.

    A and B hold FP8 data (uint8 storage) dequantized via the per-tensor
    scale_inv entries indexed by A_fp8_tensor / B_fp8_tensor. Returns `out`
    (and `gelu_input` if gelu=True) only when the output buffer was
    allocated here.
    """
    # An FP8 output additionally needs scale/amax metadata to quantize into.
    if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
        assert fp8_meta_tensor is not None and out_index is not None
    return_output = False
    if out is None:
        # "TN" layout in column major => row-major result of shape
        # (B rows, A rows).
        out = paddle.empty(
            shape=[
                B.shape[0],
                A.shape[0],
            ],
            dtype=out_dtype,
        )
        return_output = True
    # Use bfloat16 as default bias_dtype
    bias_dtype = paddle.bfloat16 if bias is None else bias.dtype
    if gelu:
        # Buffer for the pre-GELU activations, filled by te_gemm.
        gelu_input = paddle.empty_like(out, dtype=bias_dtype)
    else:
        gelu_input = None
    bias_dtype = TE_DType[bias_dtype]
    out_dtype = TE_DType[out.dtype] if D_dtype is None else D_dtype

    # Positional argument order must match the te_gemm custom-op signature.
    tex.te_gemm(
        A,
        A_scale_inv,
        B,
        B_scale_inv,
        bias if use_bias else None,
        out,
        None if out_index is None else fp8_meta_tensor.scale,
        None if out_index is None else fp8_meta_tensor.amax_history,
        gelu_input,    # this is pre_gelu_out
        workspace,
        int(A_fp8_tensor),
        int(B_fp8_tensor),
        0 if out_index is None else out_index,
        int(A_dtype),
        int(B_dtype),
        int(out_dtype),
        int(bias_dtype),
        True,    # transa
        False,    # transb
        False,    # grad
        workspace.shape[0],
        accumulate,
        use_split_accumulator,
        0,    # math_sm_count
    )

    if return_output:
        if gelu:
            return out, gelu_input
        return out
    if gelu:
        return gelu_input
    return None
def cast_to_fp8(
    inp: paddle.Tensor,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
) -> paddle.Tensor:
    """Cast input to FP8"""
    results = tex.cast_to_fp8(
        inp,
        fp8_meta_tensor.scale,
        fp8_meta_tensor.amax_history,
        fp8_meta_tensor.scale_inv,
        int(fp8_tensor),
        int(otype),
    )
    # The op also hands back the in-place-updated amax/scale_inv tensors;
    # only the quantized output matters to callers.
    return results[0]
def cast_from_fp8(
    inp: paddle.Tensor,
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    itype: tex.DType,
    otype: tex.DType,
) -> paddle.Tensor:
    """Cast input from FP8"""
    # Dequantize using the scale_inv entry recorded for this tensor slot.
    meta_index = int(fp8_tensor)
    return tex.cast_from_fp8(
        inp,
        fp8_meta_tensor.scale_inv,
        meta_index,
        int(itype),
        int(otype),
    )
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "common.h"
namespace transformer_engine {
namespace paddle_ext {
// Wrap a raw device buffer in an NVTE TensorWrapper without scaling metadata.
TensorWrapper MakeNvteTensor(void *data_ptr, const std::vector<size_t> &shape, const DType type) {
  return TensorWrapper(data_ptr, shape, type);
}

// Wrap a raw device buffer together with its FP8 scaling metadata
// (amax / scale / scale_inv pointers, each a single float).
TensorWrapper MakeNvteTensor(void *data_ptr, const std::vector<size_t> &shape, const DType type,
                             void *amax_ptr, void *scale_ptr, void *scale_inv_ptr) {
  return TensorWrapper(data_ptr, shape, type, reinterpret_cast<float *>(amax_ptr),
                       reinterpret_cast<float *>(scale_ptr),
                       reinterpret_cast<float *>(scale_inv_ptr));
}

// Wrap an existing Paddle tensor; shape and dtype are derived from the tensor.
TensorWrapper MakeNvteTensor(const paddle::Tensor &tensor) {
  return MakeNvteTensor(const_cast<void *>(tensor.data()), GetShapeArray(tensor),
                        Paddle2NvteDType(tensor.dtype()));
}
} // namespace paddle_ext
} // namespace transformer_engine
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#pragma once
#include <cublasLt.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/logging.h>
#include <transformer_engine/transformer_engine.h>
#include <vector>
#include "paddle/extension.h"
namespace transformer_engine {
namespace paddle_ext {
// Each tensor here is shape (N, ) holding all scaling
// data for a single FP8 block, e.g. LayerNormLinear
// Each tensor here is shape (N, ) holding all scaling
// data for a single FP8 block, e.g. LayerNormLinear
class FP8TensorMeta {
 public:
  paddle::Tensor scale;         // per-slot quantization scales
  paddle::Tensor scale_inv;     // reciprocal scales used for dequantization
  paddle::Tensor amax_history;  // rolling (history_len, N) amax records
};

// Used as named indices on the `scale`, `scale_inv`,
// and `amax` tensors in the `FP8TensorMeta` class.
enum class FP8FwdTensors {
  GEMM1_INPUT = 0,
  GEMM1_WEIGHT = 1,
  GEMM1_OUTPUT = 2,
  GEMM2_INPUT = 3,
  GEMM2_WEIGHT = 4,
  GEMM2_OUTPUT = 5
};

// Used as named indices on the `scale`, `scale_inv`,
// and `amax` tensors in the `FP8TensorMeta` class.
enum class FP8BwdTensors { GRAD_OUTPUT1 = 0, GRAD_INPUT1 = 1, GRAD_OUTPUT2 = 2, GRAD_INPUT2 = 3 };
// Paddle Tensor Utils
// Paddle Tensor Utils

// Bounds-checked pointer to element `index` of tensor `x` (const overload).
template <typename T>
inline const void *GetDataPtr(const paddle::Tensor &x, int64_t index) {
  if (index < 0 || index >= x.numel()) {
    NVTE_ERROR("Index out of bound");
  }
  return reinterpret_cast<const void *>(x.data<T>() + static_cast<size_t>(index));
}

// Bounds-checked pointer to element `index` of tensor `x` (mutable overload).
template <typename T>
inline void *GetDataPtr(paddle::Tensor &x, int64_t index) {  // NOLINT
  if (index < 0 || index >= x.numel()) {
    NVTE_ERROR("Index out of bound");
  }
  return reinterpret_cast<void *>(x.data<T>() + static_cast<size_t>(index));
}

// As above, but for optional tensors: nullptr when the tensor is absent.
template <typename T>
inline const void *GetOptionalDataPtr(const paddle::optional<paddle::Tensor> &x, int64_t index) {
  return x ? GetDataPtr<T>(*x, index) : nullptr;
}

template <typename T>
inline void *GetOptionalDataPtr(paddle::optional<paddle::Tensor> &x, int64_t index) {  // NOLINT
  return x ? GetDataPtr<T>(*x, index) : nullptr;
}

// Untyped base pointer of an optional tensor; nullptr when absent.
inline const void *GetOptionalDataPtr(const paddle::optional<paddle::Tensor> &x) {
  return x ? x->data() : nullptr;
}

inline void *GetOptionalDataPtr(paddle::optional<paddle::Tensor> &x) {  // NOLINT
  return x ? x->data() : nullptr;
}
// Convert a Paddle tensor's shape into the size_t vector NVTE expects.
inline std::vector<size_t> GetShapeArray(const paddle::Tensor &x) {
  std::vector<size_t> shapes;
  for (auto dim : x.shape()) {
    shapes.push_back(static_cast<size_t>(dim));
  }
  return shapes;
}

// Optional-tensor overload; {0} when the tensor is absent. Defined AFTER the
// paddle::Tensor overload: the unqualified call below is looked up at this
// definition point (ADL via paddle::Tensor does not search paddle_ext), so
// the Tensor overload must already be declared.
inline std::vector<size_t> GetShapeArray(const paddle::optional<paddle::Tensor> &x) {
  if (x) return GetShapeArray(x.get());
  return {0};
}
// DType Utils
// Map a transformer_engine DType to the Paddle dtype used to allocate
// storage for it.
inline paddle::DataType Nvte2PaddleDType(DType t) {
  switch (t) {
    case DType::kInt32:
    case DType::kFloat32:
      // NOTE(review): kInt32 deliberately shares FLOAT32 storage (same
      // 4-byte element size) — confirm no caller relies on the Paddle
      // dtype tag being INT32.
      return paddle::DataType::FLOAT32;
    case DType::kFloat16:
      return paddle::DataType::FLOAT16;
    case DType::kBFloat16:
      return paddle::DataType::BFLOAT16;
    case DType::kByte:
    case DType::kFloat8E4M3:
    case DType::kFloat8E5M2:
      // FP8 values live in plain uint8 buffers.
      return paddle::DataType::UINT8;
    default:
      NVTE_ERROR("Invalid type");
  }
}
// Map a Paddle dtype to the transformer_engine DType the kernels consume.
inline DType Paddle2NvteDType(paddle::DataType t) {
  switch (t) {
    case paddle::DataType::FLOAT16:
      return DType::kFloat16;
    case paddle::DataType::FLOAT32:
      return DType::kFloat32;
    case paddle::DataType::BFLOAT16:
      return DType::kBFloat16;
    case paddle::DataType::BOOL:
      // bool tensors are treated as raw byte data.
      return DType::kByte;
    case paddle::DataType::UINT8:
      return DType::kByte;
    case paddle::DataType::INT32:
      return DType::kInt32;
    case paddle::DataType::INT64:
      return DType::kInt64;
    default:
      NVTE_ERROR("Invalid type");
  }
}
// Validate an integer dtype code coming from Python and convert it to DType.
inline DType Int2NvteDType(int64_t dtype) {
  if (dtype < 0 || dtype >= static_cast<int64_t>(DType::kNumTypes)) {
    NVTE_ERROR("Type not supported.");
  }
  return static_cast<DType>(dtype);
}
// NVTE Tensor Utils
TensorWrapper MakeNvteTensor(void *data_ptr, const std::vector<size_t> &shape, const DType type);
TensorWrapper MakeNvteTensor(void *data_ptr, const std::vector<size_t> &shape, const DType type,
void *amax_ptr, void *scale_ptr, void *scale_inv_ptr);
TensorWrapper MakeNvteTensor(const paddle::Tensor &tensor);
} // namespace paddle_ext
} // namespace transformer_engine
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <vector>
#include "common.h"
namespace transformer_engine {
namespace paddle_ext {
// Quantize `input` to the FP8 format `otype`. The amax/scale/scale_inv
// entries at `index` in the meta tensors are read and updated in place by
// the NVTE kernel.
std::vector<paddle::Tensor> cast_to_fp8(const paddle::Tensor &input, const paddle::Tensor &scale,
                                        paddle::Tensor &amax, paddle::Tensor &scale_inv,  // NOLINT
                                        int64_t index, int64_t otype) {
  auto shape = GetShapeArray(input);

  // FP8 data is stored in a uint8 buffer of the same shape as the input.
  auto output = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype)));

  auto input_cu = MakeNvteTensor(input);
  auto output_cu = MakeNvteTensor(
      output.data(), shape, Int2NvteDType(otype), GetDataPtr<float>(amax, index),
      const_cast<void *>(GetDataPtr<float>(scale, index)), GetDataPtr<float>(scale_inv, index));

  // Launch asynchronously on the input tensor's CUDA stream.
  nvte_fp8_quantize(input_cu.data(), output_cu.data(), input.stream());

  return {output};
}
// Dequantize FP8 `input` (stored as uint8, interpreted as `itype`) into a
// tensor of dtype `otype`, using the scale_inv entry at `index`.
std::vector<paddle::Tensor> cast_from_fp8(const paddle::Tensor &input,
                                          const paddle::Tensor &scale_inv, int64_t index,
                                          int64_t itype, int64_t otype) {
  auto shape = GetShapeArray(input);

  auto output = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype)));
  auto input_cu =
      MakeNvteTensor(const_cast<void *>(input.data()), shape, Int2NvteDType(itype), nullptr,
                     nullptr, const_cast<void *>(GetDataPtr<float>(scale_inv, index)));
  auto output_cu = MakeNvteTensor(output);

  // Launch asynchronously on the input tensor's CUDA stream.
  nvte_fp8_dequantize(input_cu.data(), output_cu.data(), input.stream());

  return {output};
}
// cuBLASLt GEMM kernel entry point. Computes D = op(A) * op(B) with optional
// bias and fused GELU, writing into (and optionally accumulating with) D.
// Scale-inverse / scale / amax tensors are only consulted for FP8 operands
// and outputs; the *_index arguments select the slot within those tensors.
void te_gemm(const paddle::Tensor &A, const paddle::optional<paddle::Tensor> &A_scale_inverse,
             const paddle::Tensor &B, const paddle::optional<paddle::Tensor> &B_scale_inverse,
             const paddle::optional<paddle::Tensor> &bias, paddle::Tensor &D,        // NOLINT
             paddle::optional<paddle::Tensor> &D_scale,                              // NOLINT
             paddle::optional<paddle::Tensor> &D_amax,                               // NOLINT
             paddle::optional<paddle::Tensor> &pre_gelu_out, paddle::Tensor &workspace,  // NOLINT
             int64_t A_index, int64_t B_index, int64_t D_index, int64_t A_type, int64_t B_type,
             int64_t D_type, int64_t bias_type, bool transa, bool transb, bool grad,
             int64_t workspace_size, bool accumulate, bool use_split_accumulator,
             int64_t math_sm_count) {
  // Inputs carry only a scale_inv (dequantization) factor.
  auto te_A = MakeNvteTensor(
      const_cast<void *>(A.data()), GetShapeArray(A), Int2NvteDType(A_type), nullptr, nullptr,
      const_cast<void *>(GetOptionalDataPtr<float>(A_scale_inverse, A_index)));
  auto te_B = MakeNvteTensor(
      const_cast<void *>(B.data()), GetShapeArray(B), Int2NvteDType(B_type), nullptr, nullptr,
      const_cast<void *>(GetOptionalDataPtr<float>(B_scale_inverse, B_index)));
  // The output carries amax/scale so an FP8 result can be quantized.
  auto te_D = MakeNvteTensor(D.data(), GetShapeArray(D), Int2NvteDType(D_type),
                             GetOptionalDataPtr<float>(D_amax, D_index),
                             GetOptionalDataPtr<float>(D_scale, D_index), nullptr);
  auto te_bias = MakeNvteTensor(const_cast<void *>(GetOptionalDataPtr(bias)), GetShapeArray(bias),
                                Int2NvteDType(bias_type));

  // The pre-GELU buffer defaults to the output dtype when absent.
  DType gelu_dtype =
      pre_gelu_out ? Paddle2NvteDType(pre_gelu_out->dtype()) : Int2NvteDType(D_type);
  auto te_pre_gelu_out =
      MakeNvteTensor(GetOptionalDataPtr(pre_gelu_out), GetShapeArray(pre_gelu_out), gelu_dtype);
  auto te_workspace =
      MakeNvteTensor(workspace.data(), {static_cast<size_t>(workspace_size)}, DType::kByte);

  // Launch on A's CUDA stream.
  nvte_cublas_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(), te_pre_gelu_out.data(),
                   transa, transb, grad, te_workspace.data(), accumulate, use_split_accumulator,
                   math_sm_count, A.stream());
}
} // namespace paddle_ext
} // namespace transformer_engine
// Register the custom Paddle operators. Inputs whose names begin with an
// underscore are aliased to the same-named outputs via SetInplaceMap, which
// is Paddle's mechanism for letting an op mutate a tensor in place.
PD_BUILD_OP(te_gemm)
    .Inputs({"A", paddle::Optional("A_scale_inverse"), "B", paddle::Optional("B_scale_inverse"),
             paddle::Optional("bias"), "_D", paddle::Optional("_D_scale"),
             paddle::Optional("_D_amax"), paddle::Optional("_pre_gelu_out"), "_workspace"})
    .Outputs({"D", paddle::Optional("D_scale"), paddle::Optional("D_amax"),
              paddle::Optional("pre_gelu_out"), "workspace"})
    .Attrs({"A_index: int64_t", "B_index: int64_t", "D_index: int64_t", "A_type: int64_t",
            "B_type: int64_t", "D_type: int64_t", "bias_type: int64_t", "transa: bool",
            "transb: bool", "grad: bool", "workspace_size: int64_t", "accumulate: bool",
            "use_split_accumulator: bool", "math_sm_count: int64_t"})
    .SetInplaceMap({{"_D", "D"},
                    {paddle::Optional("_D_scale"), paddle::Optional("D_scale")},
                    {paddle::Optional("_D_amax"), paddle::Optional("D_amax")},
                    {paddle::Optional("_pre_gelu_out"), paddle::Optional("pre_gelu_out")},
                    {"_workspace", "workspace"}})
    .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_gemm));

// cast_to_fp8 updates the amax/scale_inv metadata in place alongside the output.
PD_BUILD_OP(cast_to_fp8)
    .Inputs({"Input", "Scale", "_Amax", "_ScaleInv"})
    .Outputs({"Output", "Amax", "ScaleInv"})
    .Attrs({"index: int64_t", "otype: int64_t"})
    .SetInplaceMap({{"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}})
    .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::cast_to_fp8));

// cast_from_fp8 is a pure function: no in-place tensors.
PD_BUILD_OP(cast_from_fp8)
    .Inputs({"Input", "ScaleInv"})
    .Outputs({"Output"})
    .Attrs({"index: int64_t", "itype: int64_t", "otype: int64_t"})
    .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::cast_from_fp8));
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "common.h"
namespace transformer_engine {
namespace paddle_ext {
// Version of the linked cuBLASLt library; used by the Python side to gate
// FP8 support on Ada-class GPUs.
size_t get_cublasLt_version() { return cublasLtGetVersion(); }

// Python bindings for the Paddle extension: misc helpers, the FP8 scaling
// metadata container, and the dtype / tensor-slot enums mirrored in Python.
PYBIND11_MODULE(transformer_engine_paddle, m) {
  // Misc
  m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version");
  // Data structures
  py::class_<FP8TensorMeta>(m, "FP8TensorMeta")
      .def(py::init<>())
      .def_readwrite("scale", &FP8TensorMeta::scale)
      .def_readwrite("scale_inv", &FP8TensorMeta::scale_inv)
      .def_readwrite("amax_history", &FP8TensorMeta::amax_history);
  // module_local avoids clashing with DType enums bound by other TE frameworks.
  py::enum_<DType>(m, "DType", py::module_local())
      .value("kByte", DType::kByte)
      .value("kInt32", DType::kInt32)
      .value("kFloat32", DType::kFloat32)
      .value("kFloat16", DType::kFloat16)
      .value("kBFloat16", DType::kBFloat16)
      .value("kFloat8E4M3", DType::kFloat8E4M3)
      .value("kFloat8E5M2", DType::kFloat8E5M2);
  py::enum_<FP8FwdTensors>(m, "FP8FwdTensors")
      .value("GEMM1_INPUT", FP8FwdTensors::GEMM1_INPUT)
      .value("GEMM1_WEIGHT", FP8FwdTensors::GEMM1_WEIGHT)
      .value("GEMM1_OUTPUT", FP8FwdTensors::GEMM1_OUTPUT)
      .value("GEMM2_INPUT", FP8FwdTensors::GEMM2_INPUT)
      .value("GEMM2_WEIGHT", FP8FwdTensors::GEMM2_WEIGHT)
      .value("GEMM2_OUTPUT", FP8FwdTensors::GEMM2_OUTPUT);
  py::enum_<FP8BwdTensors>(m, "FP8BwdTensors")
      .value("GRAD_OUTPUT1", FP8BwdTensors::GRAD_OUTPUT1)
      .value("GRAD_INPUT1", FP8BwdTensors::GRAD_INPUT1)
      .value("GRAD_OUTPUT2", FP8BwdTensors::GRAD_OUTPUT2)
      .value("GRAD_INPUT2", FP8BwdTensors::GRAD_INPUT2);
}
} // namespace paddle_ext
} // namespace transformer_engine
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""FP8 utilities for TransformerEngine"""
from typing import Tuple
import paddle
import transformer_engine_paddle as tex
_is_fp8_available = None
_reason_for_no_fp8 = ""
def _check_fp8_support() -> Tuple[bool, str]:
    """Return if fp8 support is available"""
    arch = paddle.device.cuda.get_device_capability()
    # Hopper (SM90) and newer always support FP8.
    if arch >= (9, 0):
        return True, ""
    # Anything older than Ada (SM89) cannot run FP8 at all.
    if arch < (8, 9):
        return False, "Device compute capability 8.9 or higher required for FP8 execution."
    # Special handling for Ada: needs sufficiently new cuBLASLt and CUDA.
    if tex.get_cublasLt_version() < 120103:
        return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada."
    cuda_version_str = paddle.version.cuda()
    if not cuda_version_str:
        return False, "Cuda version 12.1 or higher required for FP8 execution on Ada."
    if tuple(int(v) for v in cuda_version_str.split(".")) < (12, 1):
        return False, "Cuda version 12.1 or higher required for FP8 execution on Ada."
    return True, ""
def is_fp8_available() -> Tuple[bool, str]:
    """Return if fp8 support is available"""
    global _is_fp8_available, _reason_for_no_fp8
    # Probe the device only on the first call; later calls hit the cache.
    if _is_fp8_available is None:
        available, reason = _check_fp8_support()
        _is_fp8_available = available
        _reason_for_no_fp8 = reason
    return _is_fp8_available, _reason_for_no_fp8
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment