Unverified commit c1b915ae, authored by Kirthi Shankar Sivamani and committed by GitHub
Browse files

Build system refactor for wheels (#877)



Cleanup
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent fc989613
...@@ -17,7 +17,7 @@ from utils import ( ...@@ -17,7 +17,7 @@ from utils import (
is_fused_attention_supported, is_fused_attention_supported,
) )
import transformer_engine_paddle as tex from transformer_engine import transformer_engine_paddle as tex
from transformer_engine.paddle.cpp_extensions import ( from transformer_engine.paddle.cpp_extensions import (
cast_to_fp8, cast_to_fp8,
cast_from_fp8, cast_from_fp8,
......
...@@ -19,7 +19,7 @@ from transformer_engine.paddle.constants import ( ...@@ -19,7 +19,7 @@ from transformer_engine.paddle.constants import (
FusedAttnBackend, FusedAttnBackend,
) )
from transformer_engine.paddle.fp8 import FP8TensorMeta from transformer_engine.paddle.fp8 import FP8TensorMeta
import transformer_engine_paddle as tex # pylint: disable=wrong-import-order from transformer_engine import transformer_engine_paddle as tex # pylint: disable=wrong-import-order
def create_fp8_meta(num_gemms=1, amax_history_len=10): def create_fp8_meta(num_gemms=1, amax_history_len=10):
......
...@@ -6,7 +6,7 @@ import os, sys ...@@ -6,7 +6,7 @@ import os, sys
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from transformer_engine.pytorch.attention import DotProductAttention from transformer_engine.pytorch.attention import DotProductAttention
import transformer_engine_extensions as tex import transformer_engine_torch as tex
from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn
dtypes={'fp16' : torch.float16, 'bf16' : torch.bfloat16} dtypes={'fp16' : torch.float16, 'bf16' : torch.bfloat16}
......
...@@ -38,8 +38,8 @@ from transformer_engine.pytorch.utils import ( ...@@ -38,8 +38,8 @@ from transformer_engine.pytorch.utils import (
scaled_init_method_normal, scaled_init_method_normal,
is_bf16_compatible, is_bf16_compatible,
) )
import transformer_engine_extensions as tex import transformer_engine_torch as tex
from transformer_engine_extensions import NVTE_Fused_Attn_Backend from transformer_engine_torch import NVTE_Fused_Attn_Backend
# Only run FP8 tests on H100 # Only run FP8 tests on H100
fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available() fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available()
......
...@@ -13,7 +13,7 @@ import transformer_engine.common.recipe ...@@ -13,7 +13,7 @@ import transformer_engine.common.recipe
import transformer_engine.pytorch as te import transformer_engine.pytorch as te
from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import transformer_engine_extensions as tex import transformer_engine_torch as tex
# PyTorch tensor dtypes # PyTorch tensor dtypes
_dtypes: List[torch.dtype] = [torch.float32, torch.float16, torch.bfloat16] _dtypes: List[torch.dtype] = [torch.float32, torch.float16, torch.bfloat16]
......
...@@ -6,7 +6,7 @@ import pytest ...@@ -6,7 +6,7 @@ import pytest
import torch import torch
import transformer_engine.pytorch as te import transformer_engine.pytorch as te
import transformer_engine_extensions as tex import transformer_engine_torch as tex
from transformer_engine.pytorch.optimizers import MultiTensorApply from transformer_engine.pytorch.optimizers import MultiTensorApply
input_size_pairs = [ input_size_pairs = [
......
...@@ -31,7 +31,7 @@ from torch import nn as nn ...@@ -31,7 +31,7 @@ from torch import nn as nn
from typing import Optional, Union, Tuple, List from typing import Optional, Union, Tuple, List
import transformer_engine.pytorch as te import transformer_engine.pytorch as te
from transformer_engine.common import recipe from transformer_engine.common import recipe
import transformer_engine_extensions as tex import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions import gemm, fp8_gemm, gelu, cast_to_fp8, cast_from_fp8 from transformer_engine.pytorch.cpp_extensions import gemm, fp8_gemm, gelu, cast_to_fp8, cast_from_fp8
from transformer_engine.pytorch.module.base import get_workspace from transformer_engine.pytorch.module.base import get_workspace
import transformer_engine.pytorch.cpp_extensions as texcpp import transformer_engine.pytorch.cpp_extensions as texcpp
......
...@@ -9,7 +9,7 @@ import torch ...@@ -9,7 +9,7 @@ import torch
import transformer_engine.common.recipe import transformer_engine.common.recipe
import transformer_engine.pytorch as te import transformer_engine.pytorch as te
import transformer_engine_extensions as tex import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import ( from transformer_engine.pytorch.fp8 import (
FP8GlobalStateManager, FP8GlobalStateManager,
_amax_and_scale_update, _amax_and_scale_update,
......
...@@ -29,7 +29,7 @@ from transformer_engine.pytorch import ( ...@@ -29,7 +29,7 @@ from transformer_engine.pytorch import (
get_cpu_offload_context, get_cpu_offload_context,
) )
from transformer_engine.common import recipe from transformer_engine.common import recipe
import transformer_engine_extensions as tex import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions import gemm, fp8_gemm, gelu, cast_to_fp8, cast_from_fp8 from transformer_engine.pytorch.cpp_extensions import gemm, fp8_gemm, gelu, cast_to_fp8, cast_from_fp8
from transformer_engine.pytorch.module.base import get_workspace from transformer_engine.pytorch.module.base import get_workspace
from test_onnx_export import create_meta from test_onnx_export import create_meta
......
...@@ -18,7 +18,7 @@ from typing import Iterable, Union ...@@ -18,7 +18,7 @@ from typing import Iterable, Union
import pytest import pytest
import torch import torch
import transformer_engine.pytorch as te import transformer_engine.pytorch as te
import transformer_engine_extensions as tex import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions import fp8_gemm, cast_to_fp8 from transformer_engine.pytorch.cpp_extensions import fp8_gemm, cast_to_fp8
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.module.base import get_workspace from transformer_engine.pytorch.module.base import get_workspace
......
...@@ -41,11 +41,3 @@ if(NVTE_WITH_USERBUFFERS) ...@@ -41,11 +41,3 @@ if(NVTE_WITH_USERBUFFERS)
message(STATUS "userbuffers support enabled") message(STATUS "userbuffers support enabled")
add_subdirectory(pytorch/csrc/userbuffers) add_subdirectory(pytorch/csrc/userbuffers)
endif() endif()
option(ENABLE_JAX "Enable JAX in the building workflow." OFF)
message(STATUS "JAX support: ${ENABLE_JAX}")
if(ENABLE_JAX)
find_package(pybind11 CONFIG REQUIRED)
add_subdirectory(jax)
endif()
...@@ -3,15 +3,35 @@ ...@@ -3,15 +3,35 @@
# See LICENSE for license information. # See LICENSE for license information.
"""Top level package""" """Top level package"""
from ._version import __version__
from . import common # pylint: disable=unused-import
from importlib import metadata
import transformer_engine.common
try: try:
from . import pytorch from . import pytorch
except ImportError as e: except (ImportError, StopIteration) as e:
pass pass
try: try:
from . import jax from . import jax
except ImportError as e: except (ImportError, StopIteration) as e:
pass
try:
from . import paddle
except (ImportError, StopIteration) as e:
pass
try:
import transformer_engine_jax
except ImportError:
pass
try:
import transformer_engine_paddle
except ImportError:
pass pass
__version__ = str(metadata.version("transformer_engine"))
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Version information"""
import sys
from packaging.version import Version
if sys.version_info >= (3, 8):
from importlib import metadata
else:
import importlib_metadata as metadata
def _version_str() -> str:
"""Transformer Engine version string"""
# Try getting version from package metadata
version_str = None
try:
version_str = metadata.version("transformer_engine")
except:
pass
if version_str:
return version_str
# Try getting version from Git root directory
try:
from te_version import te_version
version_str = te_version()
except:
pass
if version_str:
return version_str
# Could not deduce version
return "0.dev0+unknown"
# Transformer Engine version
__version__: Version = Version(_version_str())
...@@ -3,25 +3,21 @@ ...@@ -3,25 +3,21 @@
# See LICENSE for license information. # See LICENSE for license information.
"""FW agnostic user-end APIs""" """FW agnostic user-end APIs"""
import ctypes import ctypes
import os import os
import platform import platform
import subprocess from pathlib import Path
import sys
import transformer_engine
def get_te_path(): def get_te_path():
"""Find Transformer Engine install path using pip""" """Find Transformer Engine install path using pip"""
return Path(transformer_engine.__path__[0]).parent
command = [sys.executable, "-m", "pip", "show", "transformer_engine"]
result = subprocess.run(command, capture_output=True, check=True, text=True)
result = result.stdout.replace("\n", ":").split(":")
return result[result.index("Location") + 1].strip()
def _load_library():
"""Load shared library with Transformer Engine C extensions"""
def _get_sys_extension():
system = platform.system() system = platform.system()
if system == "Linux": if system == "Linux":
extension = "so" extension = "so"
...@@ -31,33 +27,32 @@ def _load_library(): ...@@ -31,33 +27,32 @@ def _load_library():
extension = "dll" extension = "dll"
else: else:
raise RuntimeError(f"Unsupported operating system ({system})") raise RuntimeError(f"Unsupported operating system ({system})")
lib_name = "libtransformer_engine." + extension
dll_path = get_te_path()
dll_path = os.path.join(dll_path, lib_name)
return ctypes.CDLL(dll_path, mode=ctypes.RTLD_GLOBAL) return extension
def _load_library():
    """Load shared library with Transformer Engine C extensions

    The library is looked up first inside the ``transformer_engine``
    directory under the install path, then directly under the install path.
    The file that is found is opened with ``RTLD_GLOBAL``.
    """
    te_root = get_te_path()
    extension = _get_sys_extension()
    lib_file = f"libtransformer_engine.{extension}"

    # Prefer the copy inside the package directory; fall back to the root.
    lib_path = te_root / "transformer_engine" / lib_file
    if not lib_path.exists():
        lib_path = te_root / lib_file

    assert lib_path.exists(), f"Could not find libtransformer_engine.{extension}"
    return ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL)
def _load_userbuffers(): def _load_userbuffers():
"""Load shared library with userbuffers""" """Load shared library with userbuffers"""
system = platform.system() so_dir = get_te_path() / "transformer_engine"
if system == "Linux": so_file = so_dir / f"libtransformer_engine_userbuffers.{_get_sys_extension()}"
extension = "so"
elif system == "Darwin":
extension = "dylib"
elif system == "Windows":
extension = "dll"
else:
raise RuntimeError(f"Unsupported operating system ({system})")
lib_name = "libtransformer_engine_userbuffers." + extension
dll_path = get_te_path()
dll_path = os.path.join(dll_path, lib_name)
if os.path.exists(dll_path): if so_file.exists():
return ctypes.CDLL(dll_path, mode=ctypes.RTLD_GLOBAL) return ctypes.CDLL(so_file, mode=ctypes.RTLD_GLOBAL)
return None return None
_TE_LIB_CTYPES = _load_library() if "NVTE_PROJECT_BUILDING" not in os.environ:
_UB_LIB_CTYPES = _load_userbuffers() _TE_LIB_CTYPES = _load_library()
_UB_LIB_CTYPES = _load_userbuffers()
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
pybind11_add_module(
transformer_engine_jax
${CMAKE_CURRENT_SOURCE_DIR}/csrc/extensions.cpp
${CMAKE_CURRENT_SOURCE_DIR}/csrc/modules.cpp
${CMAKE_CURRENT_SOURCE_DIR}/csrc/utils.cu
)
target_link_libraries(transformer_engine_jax PRIVATE CUDA::cudart CUDA::cublas CUDA::cublasLt transformer_engine)
install(TARGETS transformer_engine_jax DESTINATION .)
...@@ -3,6 +3,28 @@ ...@@ -3,6 +3,28 @@
# See LICENSE for license information. # See LICENSE for license information.
"""Transformer Engine bindings for JAX""" """Transformer Engine bindings for JAX"""
# pylint: disable=wrong-import-position,wrong-import-order
import ctypes
from transformer_engine.common import get_te_path
from transformer_engine.common import _get_sys_extension
def _load_library():
    """Load shared library with Transformer Engine C extensions

    Searches the ``transformer_engine`` directory under the install path
    for a file matching ``transformer_engine_jax.*.<ext>``, then the install
    path itself. ``StopIteration`` propagates when neither location has a
    match.
    """
    pattern = f"transformer_engine_jax.*.{_get_sys_extension()}"

    # First match inside the package directory, if there is one.
    package_dir = get_te_path() / "transformer_engine"
    lib_path = next(package_dir.glob(pattern), None)
    if lib_path is None:
        # No default here: an empty result must raise StopIteration.
        lib_path = next(get_te_path().glob(pattern))

    return ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL)
_TE_JAX_LIB_CTYPES = _load_library()
from . import flax from . import flax
from .fp8 import fp8_autocast, update_collections, update_fp8_metas, get_delayed_scaling from .fp8 import fp8_autocast, update_collections, update_fp8_metas, get_delayed_scaling
from .fp8 import NVTE_FP8_COLLECTION_NAME from .fp8 import NVTE_FP8_COLLECTION_NAME
......
...@@ -21,13 +21,13 @@ from jax.sharding import PartitionSpec, NamedSharding ...@@ -21,13 +21,13 @@ from jax.sharding import PartitionSpec, NamedSharding
from jax._src.interpreters import batching from jax._src.interpreters import batching
from jax._src import dispatch from jax._src import dispatch
import transformer_engine_jax from transformer_engine import transformer_engine_jax
from transformer_engine_jax import DType as TEDType from transformer_engine.transformer_engine_jax import DType as TEDType
from transformer_engine_jax import NVTE_Bias_Type from transformer_engine.transformer_engine_jax import NVTE_Bias_Type
from transformer_engine_jax import NVTE_Mask_Type from transformer_engine.transformer_engine_jax import NVTE_Mask_Type
from transformer_engine_jax import NVTE_QKV_Layout from transformer_engine.transformer_engine_jax import NVTE_QKV_Layout
from transformer_engine_jax import NVTE_Fused_Attn_Backend from transformer_engine.transformer_engine_jax import NVTE_Fused_Attn_Backend
from transformer_engine_jax import NVTE_Activation_Type from transformer_engine.transformer_engine_jax import NVTE_Activation_Type
from .sharding import all_reduce_max_along_all_axes_except_PP from .sharding import all_reduce_max_along_all_axes_except_PP
from .sharding import all_reduce_sum_along_dp_fsdp from .sharding import all_reduce_sum_along_dp_fsdp
......
...@@ -11,8 +11,8 @@ ...@@ -11,8 +11,8 @@
#include "common/include/transformer_engine/fused_attn.h" #include "common/include/transformer_engine/fused_attn.h"
#include "common/include/transformer_engine/activation.h" #include "common/include/transformer_engine/activation.h"
#include "common/include/transformer_engine/transformer_engine.h" #include "common/include/transformer_engine/transformer_engine.h"
#include "jax/csrc/modules.h" #include "modules.h"
#include "jax/csrc/utils.h" #include "utils.h"
namespace transformer_engine { namespace transformer_engine {
namespace jax { namespace jax {
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include "jax/csrc/modules.h" #include "modules.h"
#include <cublasLt.h> #include <cublasLt.h>
#include <cublas_v2.h> #include <cublas_v2.h>
...@@ -169,7 +169,9 @@ void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size ...@@ -169,7 +169,9 @@ void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size
auto *input_cast = buffers[4]; auto *input_cast = buffers[4];
auto *input_cast_trans = buffers[5]; auto *input_cast_trans = buffers[5];
float *amax_out = reinterpret_cast<float *>(buffers[6]); float *amax_out = reinterpret_cast<float *>(buffers[6]);
assert(amax == amax_out); NVTE_CHECK(
amax == amax_out,
"Internal TE/JAX error: amax_out should be bound to amax in the JAX primitive.");
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len); const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
if (!use_fp8(desc.out_dtype)) { if (!use_fp8(desc.out_dtype)) {
...@@ -247,7 +249,7 @@ void ActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaqu ...@@ -247,7 +249,7 @@ void ActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaqu
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len); const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
auto m = desc.shape.dims[0]; auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1]; auto n = desc.shape.dims[1];
auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);; auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
ActLuImpl(input, m, n, desc.in_dtype, desc.out_dtype, nullptr, stream, ActLuImpl(input, m, n, desc.in_dtype, desc.out_dtype, nullptr, stream,
nullptr, nullptr, output, act_enum); nullptr, nullptr, output, act_enum);
...@@ -260,7 +262,9 @@ void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t op ...@@ -260,7 +262,9 @@ void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t op
float *scale_inv = reinterpret_cast<float *>(buffers[3]); float *scale_inv = reinterpret_cast<float *>(buffers[3]);
auto *output = buffers[4]; auto *output = buffers[4];
float *amax_out = reinterpret_cast<float *>(buffers[5]); float *amax_out = reinterpret_cast<float *>(buffers[5]);
assert(amax == amax_out); NVTE_CHECK(
amax == amax_out,
"Internal TE/JAX error: amax_out should be bound to amax in the JAX primitive.");
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len); const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
if (!use_fp8(desc.out_dtype)) { if (!use_fp8(desc.out_dtype)) {
...@@ -270,7 +274,7 @@ void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t op ...@@ -270,7 +274,7 @@ void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t op
} }
auto m = desc.shape.dims[0]; auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1]; auto n = desc.shape.dims[1];
auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);; auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
ActLuImpl(input, m, n, desc.in_dtype, desc.out_dtype, scale, stream, ActLuImpl(input, m, n, desc.in_dtype, desc.out_dtype, scale, stream,
scale_inv, amax_out, output, act_enum); scale_inv, amax_out, output, act_enum);
...@@ -284,7 +288,7 @@ void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaq ...@@ -284,7 +288,7 @@ void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaq
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len); const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
auto m = desc.shape.dims[0]; auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1]; auto n = desc.shape.dims[1];
auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);; auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
auto act_len = get_activation_len(act_enum); auto act_len = get_activation_len(act_enum);
auto input_shape = std::vector<size_t>{m, n}; auto input_shape = std::vector<size_t>{m, n};
...@@ -381,7 +385,10 @@ void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *o ...@@ -381,7 +385,10 @@ void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *o
void *workspace_ptr = buffers[9]; void *workspace_ptr = buffers[9];
const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len); const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len);
assert(amax == amax_out); NVTE_CHECK(
amax == amax_out,
"Internal TE/JAX error: amax_out should be bound to amax in the JAX primitive.");
if (!use_fp8(desc.out_dtype)) { if (!use_fp8(desc.out_dtype)) {
scale = nullptr; scale = nullptr;
scale_inv = nullptr; scale_inv = nullptr;
...@@ -389,7 +396,7 @@ void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *o ...@@ -389,7 +396,7 @@ void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *o
} }
auto m = desc.shape.dims[0]; auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1]; auto n = desc.shape.dims[1];
auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);; auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
auto input_shape = std::vector<size_t>{m, n}; auto input_shape = std::vector<size_t>{m, n};
auto act_input_shape = std::vector<size_t>{m, n}; auto act_input_shape = std::vector<size_t>{m, n};
auto output_shape = std::vector<size_t>{m, n}; auto output_shape = std::vector<size_t>{m, n};
...@@ -448,9 +455,11 @@ void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *o ...@@ -448,9 +455,11 @@ void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *o
auto *output = buffers[5]; auto *output = buffers[5];
auto *output_trans = buffers[6]; auto *output_trans = buffers[6];
float *amax_out = reinterpret_cast<float *>(buffers[7]); float *amax_out = reinterpret_cast<float *>(buffers[7]);
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len); const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
assert(amax == amax_out); NVTE_CHECK(
amax == amax_out,
"Internal TE/JAX error: amax_out should be bound to amax in the JAX primitive.");
if (!use_fp8(desc.out_dtype)) { if (!use_fp8(desc.out_dtype)) {
scale = nullptr; scale = nullptr;
scale_inv = nullptr; scale_inv = nullptr;
...@@ -458,7 +467,7 @@ void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *o ...@@ -458,7 +467,7 @@ void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *o
} }
auto m = desc.shape.dims[0]; auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1]; auto n = desc.shape.dims[1];
auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);; auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
auto input_shape = desc.shape.to_vector(); auto input_shape = desc.shape.to_vector();
auto act_input_shape = std::vector<size_t>{m, n * 2}; auto act_input_shape = std::vector<size_t>{m, n * 2};
auto output_shape = std::vector<size_t>{m, n * 2}; auto output_shape = std::vector<size_t>{m, n * 2};
...@@ -538,7 +547,10 @@ void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -538,7 +547,10 @@ void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
void *workspace_ptr = buffers[8]; void *workspace_ptr = buffers[8];
const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len); const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len);
assert(amax == amax_out); NVTE_CHECK(
amax == amax_out,
"Internal TE/JAX error: amax_out should be bound to amax in the JAX primitive.");
if (!use_fp8(desc.out_dtype)) { if (!use_fp8(desc.out_dtype)) {
scale = nullptr; scale = nullptr;
scale_inv = nullptr; scale_inv = nullptr;
...@@ -773,7 +785,9 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque ...@@ -773,7 +785,9 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque
auto *amax_out = buffers[9]; auto *amax_out = buffers[9];
auto *workspace = buffers[10]; auto *workspace = buffers[10];
auto *barrier = buffers[11]; auto *barrier = buffers[11];
assert(amax_out == amax); NVTE_CHECK(
amax == amax_out,
"Internal TE/JAX error: amax_out should be bound to amax in the JAX primitive.");
const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len); const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
auto batch_size = desc.batch_size; auto batch_size = desc.batch_size;
...@@ -786,7 +800,6 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque ...@@ -786,7 +800,6 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque
auto barrier_dtype = desc.barrier_dtype; auto barrier_dtype = desc.barrier_dtype;
auto eps = desc.eps; auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma; auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
auto out_dtype = DType::kFloat8E4M3; auto out_dtype = DType::kFloat8E4M3;
...@@ -822,7 +835,6 @@ void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, s ...@@ -822,7 +835,6 @@ void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, s
auto eps = desc.eps; auto eps = desc.eps;
auto out_dtype = in_dtype; auto out_dtype = in_dtype;
auto zero_centered_gamma = desc.zero_centered_gamma; auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma, LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
...@@ -847,7 +859,6 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -847,7 +859,6 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto dbeta_part_dtype = desc.dbeta_part_dtype; auto dbeta_part_dtype = desc.dbeta_part_dtype;
auto eps = desc.eps; auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma; auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
auto *ograd = buffers[0]; auto *ograd = buffers[0];
auto *mu = buffers[1]; auto *mu = buffers[1];
...@@ -880,7 +891,9 @@ void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -880,7 +891,9 @@ void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
auto *amax_out = buffers[7]; auto *amax_out = buffers[7];
auto *workspace = buffers[8]; auto *workspace = buffers[8];
auto *barrier = buffers[9]; auto *barrier = buffers[9];
assert(amax_out == amax); NVTE_CHECK(
amax == amax_out,
"Internal TE/JAX error: amax_out should be bound to amax in the JAX primitive.");
void *bias = nullptr; void *bias = nullptr;
void *mu = nullptr; void *mu = nullptr;
...@@ -896,7 +909,6 @@ void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -896,7 +909,6 @@ void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
auto barrier_dtype = desc.barrier_dtype; auto barrier_dtype = desc.barrier_dtype;
auto eps = desc.eps; auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma; auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
auto out_dtype = DType::kFloat8E4M3; auto out_dtype = DType::kFloat8E4M3;
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma, LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
...@@ -930,7 +942,6 @@ void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, siz ...@@ -930,7 +942,6 @@ void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, siz
auto barrier_dtype = desc.barrier_dtype; auto barrier_dtype = desc.barrier_dtype;
auto eps = desc.eps; auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma; auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
auto out_dtype = in_dtype; auto out_dtype = in_dtype;
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma, LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
...@@ -985,7 +996,9 @@ void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t op ...@@ -985,7 +996,9 @@ void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t op
auto *scale_inv = reinterpret_cast<float *>(buffers[3]); auto *scale_inv = reinterpret_cast<float *>(buffers[3]);
auto *output = buffers[4]; auto *output = buffers[4];
auto *amax_out = reinterpret_cast<float *>(buffers[5]); auto *amax_out = reinterpret_cast<float *>(buffers[5]);
assert(amax == amax_out); NVTE_CHECK(
amax == amax_out,
"Internal TE/JAX error: amax_out should be bound to amax in the JAX primitive.");
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len); const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
auto shape = desc.shape.to_vector(); auto shape = desc.shape.to_vector();
......
...@@ -13,9 +13,12 @@ import jax.numpy as jnp ...@@ -13,9 +13,12 @@ import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict from flax.core.frozen_dict import FrozenDict
from flax.linen import fp8_ops from flax.linen import fp8_ops
from transformer_engine_jax import DType from transformer_engine.transformer_engine_jax import DType
from transformer_engine_jax import get_cublasLt_version from transformer_engine.transformer_engine_jax import get_cublasLt_version
from transformer_engine_jax import get_cuda_version, get_device_compute_capability from transformer_engine.transformer_engine_jax import (
get_cuda_version,
get_device_compute_capability,
)
from transformer_engine.common.recipe import DelayedScaling, Format from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.jax.sharding import global_shard_guard from transformer_engine.jax.sharding import global_shard_guard
from transformer_engine.jax.sharding import MeshResource from transformer_engine.jax.sharding import MeshResource
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment