Unverified Commit f6b766bd authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

[PyTorch] Build custom ORT ops before running ONNX export tests (#1252)



* Build custom ORT ops before running ONNX tests
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Remove ONNX from context parallelism tests
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Export ONNX ops that do compute in FP32

Matches internal impl of TE kernels.
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Add build script for custom ORT ops
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

---------
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
parent 54aa12a9
......@@ -6,7 +6,7 @@ set -e
: ${TE_PATH:=/opt/transformerengine}
pip install pytest==8.2.1 onnxruntime==1.13.1
pip install pytest==8.2.1
pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py
pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py
pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py
......@@ -15,7 +15,6 @@ PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v
pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py
pytest -v -s $TE_PATH/tests/pytorch/test_torch_save_load.py
pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py
......@@ -24,3 +23,9 @@ pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py
pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py
pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops_distributed.py
pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py
# Build custom ONNX extensions for ONNX export test
pip install onnxruntime==1.19.2
export CUSTOM_ORT_OPS_PATH=$TE_PATH/tests/pytorch/custom_ort_ops
bash $CUSTOM_ORT_OPS_PATH/build.sh
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
......@@ -6,5 +6,5 @@ set -e
: ${TE_PATH:=/opt/transformerengine}
pip install pytest==7.2.0 onnxruntime==1.13.1
pip install pytest==7.2.0
pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
build
onnxruntime
libcustom_ort_ops.so
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
cmake_minimum_required(VERSION 3.21)
project(custom_ort_ops LANGUAGES CXX)
# Dependencies
find_package(CUDAToolkit REQUIRED)
set(ONNX_INCLUDE_DIR ${CMAKE_SOURCE_DIR}/onnxruntime/include)
if(NOT EXISTS "${ONNX_INCLUDE_DIR}")
message(FATAL_ERROR
"Could not find ONNX Runtime headers. "
"Please clone https://github.com/microsoft/onnxruntime "
"into TransformerEngine/tests/pytorch/onnx.")
endif()
include_directories(${ONNX_INCLUDE_DIR})
# Configure library
add_library(custom_ort_ops SHARED custom_op_library.cc)
target_link_libraries(custom_ort_ops PUBLIC CUDA::cudart)
target_include_directories(custom_ort_ops PUBLIC
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_include_directories(custom_ort_ops PRIVATE
${ONNX_INCLUDE_DIR}/onnxruntime
${ONNX_INCLUDE_DIR}/onnxruntime/core/session)
# Install library
install(TARGETS custom_ort_ops DESTINATION .)
# Custom ONNX Runtime operators for Transformer Engine tests
This directory contains code that builds custom ONNX operators for use
in Transformer Engine tests. It includes basic, non-performant
implementations of the FP8 quantization and dequantization operators
that are used when exporting Transformer Engine models to ONNX.
For more information, see [the ONNX Runtime reference for custom
operators](https://onnxruntime.ai/docs/reference/operators/add-custom-op.html).
Much of the code has been adapted from [an ONNX Runtime
test](https://github.com/microsoft/onnxruntime/blob/de93f40240459953a6e3bbb86b6ad83eaeab681f/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc).
## Usage
* Build the custom operators:
```bash
$ bash TransformerEngine/tests/pytorch/custom_ort_ops/build.sh
```
* Run the ONNX export tests with pytest:
```bash
$ python -m pytest TransformerEngine/tests/pytorch/test_onnx_export.py
```
\ No newline at end of file
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -ex
: ${CUSTOM_ORT_OPS_PATH=$(dirname $(realpath $0))}
cd ${CUSTOM_ORT_OPS_PATH}
# Download ONNX Runtime source
git clone --depth=1 -b rel-1.19.2 --single-branch https://github.com/microsoft/onnxruntime.git || true
# Configure and build with CMake
mkdir -p build
cmake -S . -B build -DCMAKE_INSTALL_PREFIX=.
cmake --build build --verbose
cmake --install build --verbose
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "custom_op_library.h"
#define ORT_API_MANUAL_INIT
#include "onnxruntime_c_api.h"
#include "onnxruntime_cxx_api.h"
#undef ORT_API_MANUAL_INIT
#include <exception>
#include <memory>
#include <mutex>
#include <utility>
#include <vector>
#include "core/common/common.h"
#include "core/session/onnxruntime_lite_custom_op.h"
#include <cuda_fp8.h>
namespace {
template <typename IType, typename OType, typename CType>
void Quantize(OrtKernelContext* context,
const Ort::Custom::Tensor<IType>& input,
const Ort::Custom::Tensor<CType>& scale_inv,
Ort::Custom::Tensor<unsigned char>& output) {
auto raw_input = input.Data();
auto raw_scale_inv = scale_inv.Data();
auto raw_output = reinterpret_cast<OType*>(output.Allocate(input.Shape()));
const auto rs = static_cast<CType>(raw_scale_inv[0]);
const size_t N = input.NumberOfElement();
for (size_t i = 0; i < N; ++i) {
const auto x = static_cast<CType>(raw_input[i]);
raw_output[i] = static_cast<OType>(x / rs);
}
}
template <typename IType, typename OType, typename CType>
void Dequantize(OrtKernelContext* context,
const Ort::Custom::Tensor<unsigned char>& input,
const Ort::Custom::Tensor<CType>& scale_inv,
Ort::Custom::Tensor<OType>& output) {
auto raw_input = reinterpret_cast<const IType*>(input.Data());
auto raw_scale_inv = scale_inv.Data();
auto raw_output = output.Allocate(input.Shape());
const auto rs = static_cast<CType>(raw_scale_inv[0]);
const size_t N = input.NumberOfElement();
for (size_t i = 0; i < N; ++i) {
const auto x = rs * static_cast<CType>(raw_input[i]);
raw_output[i] = static_cast<OType>(x);
}
}
static void AddOrtCustomOpDomainToContainer(Ort::CustomOpDomain&& domain) {
static std::vector<Ort::CustomOpDomain> ort_custom_op_domain_container;
static std::mutex ort_custom_op_domain_mutex;
std::lock_guard<std::mutex> lock(ort_custom_op_domain_mutex);
ort_custom_op_domain_container.push_back(std::move(domain));
}
} // namespace
OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api) {
Ort::Global<void>::api_ = api->GetApi(ORT_API_VERSION);
// Namespace for custom ops
static const char* c_OpDomain = "trt";
// Construct custom ops
static const std::unique_ptr<Ort::Custom::OrtLiteCustomOp> c_Quantize{
Ort::Custom::CreateLiteCustomOp("TRT_FP8QuantizeLinear",
"CPUExecutionProvider",
Quantize<float, __nv_fp8_e4m3, float>)
};
static const std::unique_ptr<Ort::Custom::OrtLiteCustomOp> c_Dequantize{
Ort::Custom::CreateLiteCustomOp("TRT_FP8DequantizeLinear",
"CPUExecutionProvider",
Dequantize<__nv_fp8_e4m3, float, float>)
};
// Register custom ops
OrtStatus* result = nullptr;
ORT_TRY {
Ort::CustomOpDomain domain{c_OpDomain};
domain.Add(c_Quantize.get());
domain.Add(c_Dequantize.get());
Ort::UnownedSessionOptions session_options(options);
session_options.Add(domain);
AddOrtCustomOpDomainToContainer(std::move(domain));
}
ORT_CATCH(const std::exception& e) {
ORT_HANDLE_EXCEPTION([&]() {
Ort::Status status{e};
result = status.release();
});
}
return result;
}
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#pragma once
#include "onnxruntime/core/session/onnxruntime_c_api.h"
#ifdef __cplusplus
extern "C" {
#endif
ORT_EXPORT OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api);
#ifdef __cplusplus
}
#endif
......@@ -72,7 +72,7 @@ OPSET = 17
assert OPSET >= TRILU_OPSET
# Shared library implementing custom FP8 Q/DQ operators for ONNX Runtime (ORT).
ORT_CUSTOM_OPS_LIB = os.path.join(TESTS_DIR, "./libcustom_ort_fp8_qdq_ops.so")
ORT_CUSTOM_OPS_LIB = os.path.join(TESTS_DIR, "custom_ort_ops", "libcustom_ort_ops.so")
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
skip_FP8 = pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
......@@ -85,7 +85,7 @@ all_normalizations = ["LayerNorm", "RMSNorm"]
@pytest.fixture()
def seed_default_rng():
"""Reseed the PRNG for test reproducibility"""
torch.random.seed()
torch.manual_seed(1234)
@pytest.fixture()
......
......@@ -146,89 +146,136 @@ def onnx_cast_from_fp8(g, inputs, scale_inv, fp8_tensor, itype, otype):
@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i")
def onnx_fp8_gelu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
def onnx_fp8_gelu(g, inp, scale, amax, scale_inv, fp8_tensor, otype):
"""ONNX graph for fp8_gelu"""
# pylint: disable=unused-argument
# TE computes GELU using float32 precision so wrap the GELU subgraph with
# conversion to/from float32.
gelu = compute_in_fp32(g, inputs, torch.onnx.symbolic_opset9.gelu, "tanh")
dtype = get_TensorProtoDataType(inp)
if dtype != _type_utils.JitScalarType.FLOAT:
inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT)
out = torch.onnx.symbolic_opset9.gelu(g, inp, "tanh")
if scale:
gelu = quantize(g, gelu, scale, fp8_tensor)
return gelu
out = quantize(g, out, scale, fp8_tensor)
elif dtype != _type_utils.JitScalarType.FLOAT:
out = g.op("Cast", out, to_i=dtype)
return out
@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i")
def onnx_fp8_relu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
def onnx_fp8_relu(g, inp, scale, amax, scale_inv, fp8_tensor, otype):
"""ONNX graph for fp8_relu"""
# pylint: disable=unused-argument
relu = compute_in_fp32(g, inputs, torch.onnx.symbolic_opset9.relu)
out = torch.onnx.symbolic_opset9.relu(g, inp)
if scale:
relu = quantize(g, relu, scale, fp8_tensor)
return relu
out = quantize(g, out, scale, fp8_tensor)
return out
@symbolic_helper.parse_args("v", "i")
def onnx_swiglu(g: jit_utils.GraphContext, inp, dim):
"""ONNX graph for swiglu"""
# Check dimensions
dim_size = symbolic_helper._get_tensor_dim_size(inp, dim)
if dim_size is not None:
assert dim_size % 2 == 0
# Perform compute in FP32
dtype = get_TensorProtoDataType(inp)
if dtype != _type_utils.JitScalarType.FLOAT:
inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT)
first, second = g.op("Split", inp, axis_i=dim, outputs=2)
return g.op("Mul", g.op("Sigmoid", first), second)
out = g.op("Mul", g.op("Sigmoid", first), second)
if dtype != _type_utils.JitScalarType.FLOAT:
out = g.op("Cast", out, to_i=dtype)
return out
@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i")
def onnx_fp8_swiglu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
def onnx_fp8_swiglu(g, inp, scale, amax, scale_inv, fp8_tensor, otype):
"""ONNX graph for fp8_swiglu"""
# pylint: disable=unused-argument
swiglu = compute_in_fp32(g, inputs, onnx_swiglu, 1)
dtype = get_TensorProtoDataType(inp)
if dtype != _type_utils.JitScalarType.FLOAT:
inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT)
out = onnx_swiglu(g, inp, 1)
if scale:
swiglu = quantize(g, swiglu, scale, fp8_tensor)
return swiglu
out = quantize(g, out, scale, fp8_tensor)
elif dtype != _type_utils.JitScalarType.FLOAT:
out = g.op("Cast", out, to_i=dtype)
return out
@symbolic_helper.parse_args("v", "i")
def onnx_reglu(g: jit_utils.GraphContext, inp, dim):
"""ONNX graph for reglu"""
# Check dimensions
dim_size = symbolic_helper._get_tensor_dim_size(inp, dim)
if dim_size is not None:
assert dim_size % 2 == 0
# Perform compute in FP32
dtype = get_TensorProtoDataType(inp)
if dtype != _type_utils.JitScalarType.FLOAT:
inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT)
first, second = g.op("Split", inp, axis_i=dim, outputs=2)
return g.op("Mul", g.op("Relu", first), second)
out = g.op("Mul", g.op("Relu", first), second)
if dtype != _type_utils.JitScalarType.FLOAT:
out = g.op("Cast", out, to_i=dtype)
return out
@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i")
def onnx_fp8_reglu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
def onnx_fp8_reglu(g, inp, scale, amax, scale_inv, fp8_tensor, otype):
"""ONNX graph for fp8_reglu"""
# pylint: disable=unused-argument
reglu = compute_in_fp32(g, inputs, onnx_reglu, 1)
dtype = get_TensorProtoDataType(inp)
if dtype != _type_utils.JitScalarType.FLOAT:
inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT)
out = onnx_reglu(g, inp, 1)
if scale:
reglu = quantize(g, reglu, scale, fp8_tensor)
return reglu
out = quantize(g, out, scale, fp8_tensor)
elif dtype != _type_utils.JitScalarType.FLOAT:
out = g.op("Cast", out, to_i=dtype)
return out
@symbolic_helper.parse_args("v", "i")
def onnx_geglu(g: jit_utils.GraphContext, inp, dim):
"""ONNX graph for geglu"""
# Check dimensions
dim_size = symbolic_helper._get_tensor_dim_size(inp, dim)
if dim_size is not None:
assert dim_size % 2 == 0
# Perform compute in FP32
dtype = get_TensorProtoDataType(inp)
if dtype != _type_utils.JitScalarType.FLOAT:
inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT)
first, second = g.op("Split", inp, axis_i=dim, outputs=2)
first_gelu = torch.onnx.symbolic_opset9.gelu(g, first, "tanh")
return g.op("Mul", first_gelu, second)
first = torch.onnx.symbolic_opset9.gelu(g, first, "tanh")
out = g.op("Mul", first, second)
if dtype != _type_utils.JitScalarType.FLOAT:
out = g.op("Cast", out, to_i=dtype)
return out
@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i")
def onnx_fp8_geglu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
def onnx_fp8_geglu(g, inp, scale, amax, scale_inv, fp8_tensor, otype):
"""ONNX graph for fp8_geglu"""
# pylint: disable=unused-argument
geglu = compute_in_fp32(g, inputs, onnx_geglu, 1)
dtype = get_TensorProtoDataType(inp)
if dtype != _type_utils.JitScalarType.FLOAT:
inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT)
out = onnx_geglu(g, inp, 1)
if scale:
geglu = quantize(g, geglu, scale, fp8_tensor)
return geglu
out = quantize(g, out, scale, fp8_tensor)
elif dtype != _type_utils.JitScalarType.FLOAT:
out = g.op("Cast", out, to_i=dtype)
return out
@symbolic_helper.parse_args(
......@@ -394,7 +441,7 @@ def onnx_layernorm_fwd(g, inputs, weight, bias, eps, sm_margin, zero_centered_ga
@symbolic_helper.parse_args("v", "v", "f", "fs", "v", "v", "i", "i", "i", "b")
def onnx_rmsnorm_fwd_fp8(
g,
inputs,
inp,
weight,
eps,
scale,
......@@ -407,50 +454,54 @@ def onnx_rmsnorm_fwd_fp8(
):
"""ONNX graph for rmsnorm_fwd_fp8"""
# pylint: disable=unused-argument
inp_dtype = get_TensorProtoDataType(inputs)
if inp_dtype != get_TensorProtoDataType(weight):
weight = g.op("Cast", weight, to_i=inp_dtype)
ln = onnx_rmsnorm_fwd(g, inputs, weight, eps, sm_margin, zero_centered_gamma)
fp8_ln = quantize(g, ln, scale, fp8_tensor)
return fp8_ln
dtype = get_TensorProtoDataType(inp)
if dtype != _type_utils.JitScalarType.FLOAT:
inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT)
out = onnx_rmsnorm_fwd(g, inp, weight, eps, sm_margin, zero_centered_gamma)
out = quantize(g, out, scale, fp8_tensor)
return out
@symbolic_helper.parse_args("v", "v", "f", "i", "b")
def onnx_rmsnorm_fwd(g, inputs, weight, eps, sm_margin, zero_centered_gamma):
def onnx_rmsnorm_fwd(g, inp, weight, eps, sm_margin, zero_centered_gamma):
"""ONNX graph for rmsnorm_fwd"""
# pylint: disable=unused-argument
normalized_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inputs)
# Check dimensions
normalized_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inp)
if normalized_shape is None:
ndim = torch.onnx.symbolic_helper._get_tensor_rank(inputs)
ndim = torch.onnx.symbolic_helper._get_tensor_rank(inp)
assert ndim is not None
normalized_shape = list(range(0, ndim))
# Normalization axis = 0, so normalized_shape uses all dims except dim = 0
normalized_shape = normalized_shape[1:]
axis = -len(normalized_shape)
# Cast input tensors to FP32 if needed
dtype = get_TensorProtoDataType(inp)
if dtype != _type_utils.JitScalarType.FLOAT:
inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT)
if get_TensorProtoDataType(weight) != _type_utils.JitScalarType.FLOAT:
weight = g.op("Cast", weight, to_i=_C_onnx.TensorProtoDataType.FLOAT)
# Adjust zero-centered weights
if zero_centered_gamma:
inputs_dtype = inputs.type().dtype()
one = _ones_like(g, weight, inputs_dtype)
one = _ones_like(g, weight, torch.float32)
weight = g.op("Add", weight, one)
axis = -len(normalized_shape)
inputs_float = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT)
sum_square = g.op("ReduceSumSquare", inputs_float, axes_i=[axis])
shape = g.op("Shape", inputs_float, start_i=-1)
# Perform compute in FP32
sum_square = g.op("ReduceSumSquare", inp, axes_i=[axis])
shape = g.op("Shape", inp, start_i=-1)
shape_f = g.op("Cast", shape, to_i=_C_onnx.TensorProtoDataType.FLOAT)
mean_squared = g.op("Div", sum_square, shape_f)
eps_tensor = g.op("ConstantOfShape", shape, value_t=torch.tensor([eps], dtype=torch.float32))
rms_squared = g.op("Add", mean_squared, eps_tensor)
rms_eps = g.op("Sqrt", rms_squared)
normalized_input = g.op("Div", inputs_float, rms_eps)
result = g.op("Mul", weight, normalized_input)
result = g.op("Cast", result, to_i=get_TensorProtoDataType(inputs))
return result
normalized_input = g.op("Div", inp, rms_eps)
out = g.op("Mul", weight, normalized_input)
if dtype != _type_utils.JitScalarType.FLOAT:
out = g.op("Cast", out, to_i=dtype)
return out
register_custom_op_symbolic("tex_ts::cast_to_fp8_ts", onnx_cast_to_fp8, VER)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment