Unverified commit c1b915ae, authored by Kirthi Shankar Sivamani, committed by GitHub
Browse files

Build system refactor for wheels (#877)



Cleanup
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent fc989613
......@@ -17,7 +17,7 @@ from utils import (
is_fused_attention_supported,
)
import transformer_engine_paddle as tex
from transformer_engine import transformer_engine_paddle as tex
from transformer_engine.paddle.cpp_extensions import (
cast_to_fp8,
cast_from_fp8,
......
......@@ -19,7 +19,7 @@ from transformer_engine.paddle.constants import (
FusedAttnBackend,
)
from transformer_engine.paddle.fp8 import FP8TensorMeta
import transformer_engine_paddle as tex # pylint: disable=wrong-import-order
from transformer_engine import transformer_engine_paddle as tex # pylint: disable=wrong-import-order
def create_fp8_meta(num_gemms=1, amax_history_len=10):
......
......@@ -6,7 +6,7 @@ import os, sys
import torch
import torch.distributed as dist
from transformer_engine.pytorch.attention import DotProductAttention
import transformer_engine_extensions as tex
import transformer_engine_torch as tex
from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn
dtypes={'fp16' : torch.float16, 'bf16' : torch.bfloat16}
......
......@@ -38,8 +38,8 @@ from transformer_engine.pytorch.utils import (
scaled_init_method_normal,
is_bf16_compatible,
)
import transformer_engine_extensions as tex
from transformer_engine_extensions import NVTE_Fused_Attn_Backend
import transformer_engine_torch as tex
from transformer_engine_torch import NVTE_Fused_Attn_Backend
# Only run FP8 tests on H100
fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available()
......
......@@ -13,7 +13,7 @@ import transformer_engine.common.recipe
import transformer_engine.pytorch as te
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine_extensions as tex
import transformer_engine_torch as tex
# PyTorch tensor dtypes
_dtypes: List[torch.dtype] = [torch.float32, torch.float16, torch.bfloat16]
......
......@@ -6,7 +6,7 @@ import pytest
import torch
import transformer_engine.pytorch as te
import transformer_engine_extensions as tex
import transformer_engine_torch as tex
from transformer_engine.pytorch.optimizers import MultiTensorApply
input_size_pairs = [
......
......@@ -31,7 +31,7 @@ from torch import nn as nn
from typing import Optional, Union, Tuple, List
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
import transformer_engine_extensions as tex
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions import gemm, fp8_gemm, gelu, cast_to_fp8, cast_from_fp8
from transformer_engine.pytorch.module.base import get_workspace
import transformer_engine.pytorch.cpp_extensions as texcpp
......
......@@ -9,7 +9,7 @@ import torch
import transformer_engine.common.recipe
import transformer_engine.pytorch as te
import transformer_engine_extensions as tex
import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import (
FP8GlobalStateManager,
_amax_and_scale_update,
......
......@@ -29,7 +29,7 @@ from transformer_engine.pytorch import (
get_cpu_offload_context,
)
from transformer_engine.common import recipe
import transformer_engine_extensions as tex
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions import gemm, fp8_gemm, gelu, cast_to_fp8, cast_from_fp8
from transformer_engine.pytorch.module.base import get_workspace
from test_onnx_export import create_meta
......
......@@ -18,7 +18,7 @@ from typing import Iterable, Union
import pytest
import torch
import transformer_engine.pytorch as te
import transformer_engine_extensions as tex
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions import fp8_gemm, cast_to_fp8
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.module.base import get_workspace
......
......@@ -41,11 +41,3 @@ if(NVTE_WITH_USERBUFFERS)
message(STATUS "userbuffers support enabled")
add_subdirectory(pytorch/csrc/userbuffers)
endif()
option(ENABLE_JAX "Enable JAX in the building workflow." OFF)
message(STATUS "JAX support: ${ENABLE_JAX}")
if(ENABLE_JAX)
find_package(pybind11 CONFIG REQUIRED)
add_subdirectory(jax)
endif()
......@@ -3,15 +3,35 @@
# See LICENSE for license information.
"""Top level package"""
from ._version import __version__
from . import common
# pylint: disable=unused-import
from importlib import metadata
import transformer_engine.common
try:
from . import pytorch
except ImportError as e:
except (ImportError, StopIteration) as e:
pass
try:
from . import jax
except ImportError as e:
except (ImportError, StopIteration) as e:
pass
try:
from . import paddle
except (ImportError, StopIteration) as e:
pass
try:
import transformer_engine_jax
except ImportError:
pass
try:
import transformer_engine_paddle
except ImportError:
pass
__version__ = str(metadata.version("transformer_engine"))
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Version information"""
import sys
from packaging.version import Version
if sys.version_info >= (3, 8):
from importlib import metadata
else:
import importlib_metadata as metadata
def _version_str() -> str:
    """Transformer Engine version string

    Sources are tried in order and the first non-empty result wins:

    1. The installed package metadata for ``transformer_engine``.
    2. The ``te_version`` helper, when it can be imported.

    Returns
    -------
    str
        The deduced version, or ``"0.dev0+unknown"`` when no source
        yields one.
    """

    # Try getting version from package metadata
    version_str = None
    try:
        version_str = metadata.version("transformer_engine")
    except Exception:  # best-effort lookup; fall through to the next source
        # `Exception` (not a bare `except`) so that KeyboardInterrupt and
        # SystemExit still propagate.
        pass
    if version_str:
        return version_str

    # Try getting version from Git root directory
    try:
        from te_version import te_version

        version_str = te_version()
    except Exception:  # helper missing or failing; use the fallback below
        pass
    if version_str:
        return version_str

    # Could not deduce version
    return "0.dev0+unknown"
# Transformer Engine version
__version__: Version = Version(_version_str())
......@@ -3,25 +3,21 @@
# See LICENSE for license information.
"""FW agnostic user-end APIs"""
import ctypes
import os
import platform
import subprocess
import sys
from pathlib import Path
import transformer_engine
def get_te_path():
"""Find Transformer Engine install path using pip"""
return Path(transformer_engine.__path__[0]).parent
command = [sys.executable, "-m", "pip", "show", "transformer_engine"]
result = subprocess.run(command, capture_output=True, check=True, text=True)
result = result.stdout.replace("\n", ":").split(":")
return result[result.index("Location") + 1].strip()
def _load_library():
"""Load shared library with Transformer Engine C extensions"""
def _get_sys_extension():
system = platform.system()
if system == "Linux":
extension = "so"
......@@ -31,33 +27,32 @@ def _load_library():
extension = "dll"
else:
raise RuntimeError(f"Unsupported operating system ({system})")
lib_name = "libtransformer_engine." + extension
dll_path = get_te_path()
dll_path = os.path.join(dll_path, lib_name)
return ctypes.CDLL(dll_path, mode=ctypes.RTLD_GLOBAL)
return extension
def _load_library():
    """Load shared library with Transformer Engine C extensions

    Looks for the library inside the ``transformer_engine`` directory under
    the path returned by ``get_te_path()`` first, then directly under that
    path. The library is opened with ``RTLD_GLOBAL``.
    """
    te_root = get_te_path()
    lib_name = f"libtransformer_engine.{_get_sys_extension()}"

    # Preferred location: inside the package directory.
    so_path = te_root / "transformer_engine" / lib_name
    if not so_path.exists():
        # Fallback location: next to the package directory.
        so_path = te_root / lib_name

    assert so_path.exists(), f"Could not find libtransformer_engine.{_get_sys_extension()}"
    return ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL)
def _load_userbuffers():
"""Load shared library with userbuffers"""
system = platform.system()
if system == "Linux":
extension = "so"
elif system == "Darwin":
extension = "dylib"
elif system == "Windows":
extension = "dll"
else:
raise RuntimeError(f"Unsupported operating system ({system})")
lib_name = "libtransformer_engine_userbuffers." + extension
dll_path = get_te_path()
dll_path = os.path.join(dll_path, lib_name)
so_dir = get_te_path() / "transformer_engine"
so_file = so_dir / f"libtransformer_engine_userbuffers.{_get_sys_extension()}"
if os.path.exists(dll_path):
return ctypes.CDLL(dll_path, mode=ctypes.RTLD_GLOBAL)
if so_file.exists():
return ctypes.CDLL(so_file, mode=ctypes.RTLD_GLOBAL)
return None
_TE_LIB_CTYPES = _load_library()
_UB_LIB_CTYPES = _load_userbuffers()
if "NVTE_PROJECT_BUILDING" not in os.environ:
_TE_LIB_CTYPES = _load_library()
_UB_LIB_CTYPES = _load_userbuffers()
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# pybind11 module target for the Transformer Engine JAX bindings.
pybind11_add_module(
    transformer_engine_jax
    ${CMAKE_CURRENT_SOURCE_DIR}/csrc/extensions.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/csrc/modules.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/csrc/utils.cu
)

# Link privately against the CUDA runtime / cuBLAS targets and the
# transformer_engine target.
target_link_libraries(
    transformer_engine_jax
    PRIVATE
    CUDA::cudart
    CUDA::cublas
    CUDA::cublasLt
    transformer_engine
)

# Install the built module at the root of the install destination.
install(TARGETS transformer_engine_jax DESTINATION .)
......@@ -3,6 +3,28 @@
# See LICENSE for license information.
"""Transformer Engine bindings for JAX"""
# pylint: disable=wrong-import-position,wrong-import-order
import ctypes
from transformer_engine.common import get_te_path
from transformer_engine.common import _get_sys_extension
def _load_library():
    """Load shared library with Transformer Engine C extensions

    Searches for a file matching ``transformer_engine_jax.*.<extension>``,
    first inside the ``transformer_engine`` directory under the path
    returned by ``get_te_path()``, then directly under that path.

    Raises ``StopIteration`` when neither location contains a match.
    """
    extension = _get_sys_extension()
    pattern = f"transformer_engine_jax.*.{extension}"
    te_root = get_te_path()

    # Preferred location: inside the package directory.
    so_path = next((te_root / "transformer_engine").glob(pattern), None)
    if so_path is None:
        # Fallback location; an exhausted glob here raises StopIteration.
        so_path = next(te_root.glob(pattern))

    return ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL)
_TE_JAX_LIB_CTYPES = _load_library()
from . import flax
from .fp8 import fp8_autocast, update_collections, update_fp8_metas, get_delayed_scaling
from .fp8 import NVTE_FP8_COLLECTION_NAME
......
......@@ -21,13 +21,13 @@ from jax.sharding import PartitionSpec, NamedSharding
from jax._src.interpreters import batching
from jax._src import dispatch
import transformer_engine_jax
from transformer_engine_jax import DType as TEDType
from transformer_engine_jax import NVTE_Bias_Type
from transformer_engine_jax import NVTE_Mask_Type
from transformer_engine_jax import NVTE_QKV_Layout
from transformer_engine_jax import NVTE_Fused_Attn_Backend
from transformer_engine_jax import NVTE_Activation_Type
from transformer_engine import transformer_engine_jax
from transformer_engine.transformer_engine_jax import DType as TEDType
from transformer_engine.transformer_engine_jax import NVTE_Bias_Type
from transformer_engine.transformer_engine_jax import NVTE_Mask_Type
from transformer_engine.transformer_engine_jax import NVTE_QKV_Layout
from transformer_engine.transformer_engine_jax import NVTE_Fused_Attn_Backend
from transformer_engine.transformer_engine_jax import NVTE_Activation_Type
from .sharding import all_reduce_max_along_all_axes_except_PP
from .sharding import all_reduce_sum_along_dp_fsdp
......
......@@ -11,8 +11,8 @@
#include "common/include/transformer_engine/fused_attn.h"
#include "common/include/transformer_engine/activation.h"
#include "common/include/transformer_engine/transformer_engine.h"
#include "jax/csrc/modules.h"
#include "jax/csrc/utils.h"
#include "modules.h"
#include "utils.h"
namespace transformer_engine {
namespace jax {
......
......@@ -4,7 +4,7 @@
* See LICENSE for license information.
************************************************************************/
#include "jax/csrc/modules.h"
#include "modules.h"
#include <cublasLt.h>
#include <cublas_v2.h>
......@@ -169,7 +169,9 @@ void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size
auto *input_cast = buffers[4];
auto *input_cast_trans = buffers[5];
float *amax_out = reinterpret_cast<float *>(buffers[6]);
assert(amax == amax_out);
NVTE_CHECK(
amax == amax_out,
"Internal TE/JAX error: amax_out should be bound to amax in the JAX primitive.");
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
if (!use_fp8(desc.out_dtype)) {
......@@ -247,7 +249,7 @@ void ActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaqu
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);;
auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
ActLuImpl(input, m, n, desc.in_dtype, desc.out_dtype, nullptr, stream,
nullptr, nullptr, output, act_enum);
......@@ -260,7 +262,9 @@ void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t op
float *scale_inv = reinterpret_cast<float *>(buffers[3]);
auto *output = buffers[4];
float *amax_out = reinterpret_cast<float *>(buffers[5]);
assert(amax == amax_out);
NVTE_CHECK(
amax == amax_out,
"Internal TE/JAX error: amax_out should be bound to amax in the JAX primitive.");
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
if (!use_fp8(desc.out_dtype)) {
......@@ -270,7 +274,7 @@ void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t op
}
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);;
auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
ActLuImpl(input, m, n, desc.in_dtype, desc.out_dtype, scale, stream,
scale_inv, amax_out, output, act_enum);
......@@ -284,7 +288,7 @@ void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaq
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);;
auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
auto act_len = get_activation_len(act_enum);
auto input_shape = std::vector<size_t>{m, n};
......@@ -381,7 +385,10 @@ void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *o
void *workspace_ptr = buffers[9];
const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len);
assert(amax == amax_out);
NVTE_CHECK(
amax == amax_out,
"Internal TE/JAX error: amax_out should be bound to amax in the JAX primitive.");
if (!use_fp8(desc.out_dtype)) {
scale = nullptr;
scale_inv = nullptr;
......@@ -389,7 +396,7 @@ void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *o
}
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);;
auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
auto input_shape = std::vector<size_t>{m, n};
auto act_input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{m, n};
......@@ -448,9 +455,11 @@ void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *o
auto *output = buffers[5];
auto *output_trans = buffers[6];
float *amax_out = reinterpret_cast<float *>(buffers[7]);
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
assert(amax == amax_out);
NVTE_CHECK(
amax == amax_out,
"Internal TE/JAX error: amax_out should be bound to amax in the JAX primitive.");
if (!use_fp8(desc.out_dtype)) {
scale = nullptr;
scale_inv = nullptr;
......@@ -458,7 +467,7 @@ void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *o
}
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);;
auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
auto input_shape = desc.shape.to_vector();
auto act_input_shape = std::vector<size_t>{m, n * 2};
auto output_shape = std::vector<size_t>{m, n * 2};
......@@ -538,7 +547,10 @@ void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
void *workspace_ptr = buffers[8];
const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len);
assert(amax == amax_out);
NVTE_CHECK(
amax == amax_out,
"Internal TE/JAX error: amax_out should be bound to amax in the JAX primitive.");
if (!use_fp8(desc.out_dtype)) {
scale = nullptr;
scale_inv = nullptr;
......@@ -773,7 +785,9 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque
auto *amax_out = buffers[9];
auto *workspace = buffers[10];
auto *barrier = buffers[11];
assert(amax_out == amax);
NVTE_CHECK(
amax == amax_out,
"Internal TE/JAX error: amax_out should be bound to amax in the JAX primitive.");
const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
auto batch_size = desc.batch_size;
......@@ -786,7 +800,6 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque
auto barrier_dtype = desc.barrier_dtype;
auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
auto out_dtype = DType::kFloat8E4M3;
......@@ -822,7 +835,6 @@ void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, s
auto eps = desc.eps;
auto out_dtype = in_dtype;
auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
......@@ -847,7 +859,6 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto dbeta_part_dtype = desc.dbeta_part_dtype;
auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
auto *ograd = buffers[0];
auto *mu = buffers[1];
......@@ -880,7 +891,9 @@ void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
auto *amax_out = buffers[7];
auto *workspace = buffers[8];
auto *barrier = buffers[9];
assert(amax_out == amax);
NVTE_CHECK(
amax == amax_out,
"Internal TE/JAX error: amax_out should be bound to amax in the JAX primitive.");
void *bias = nullptr;
void *mu = nullptr;
......@@ -896,7 +909,6 @@ void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
auto barrier_dtype = desc.barrier_dtype;
auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
auto out_dtype = DType::kFloat8E4M3;
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
......@@ -930,7 +942,6 @@ void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, siz
auto barrier_dtype = desc.barrier_dtype;
auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
auto out_dtype = in_dtype;
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
......@@ -985,7 +996,9 @@ void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t op
auto *scale_inv = reinterpret_cast<float *>(buffers[3]);
auto *output = buffers[4];
auto *amax_out = reinterpret_cast<float *>(buffers[5]);
assert(amax == amax_out);
NVTE_CHECK(
amax == amax_out,
"Internal TE/JAX error: amax_out should be bound to amax in the JAX primitive.");
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
auto shape = desc.shape.to_vector();
......
......@@ -13,9 +13,12 @@ import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.linen import fp8_ops
from transformer_engine_jax import DType
from transformer_engine_jax import get_cublasLt_version
from transformer_engine_jax import get_cuda_version, get_device_compute_capability
from transformer_engine.transformer_engine_jax import DType
from transformer_engine.transformer_engine_jax import get_cublasLt_version
from transformer_engine.transformer_engine_jax import (
get_cuda_version,
get_device_compute_capability,
)
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.jax.sharding import global_shard_guard
from transformer_engine.jax.sharding import MeshResource
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment