"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "324be3324278723bd8f66196ed1ccac29b94bd7f"
Unverified commit 4b2b39b4, authored by Phuong Nguyen, committed via GitHub
Browse files

[TE/JAX] Prototype for New XLA Custom Calls with FFI (#946)



* implemented custom call with ffi in csrc

* moved headers of misc to misc.h, add ffi.h

* ActLu and DActLu lowering with ffi_lowering

* CastTranspose with ffi_lowering

* enabled cudaGraph

* added 4d input test case to TestActivationLu

* added operand_output_aliases for CastTranspose

* added env var NVTE_JAX_WITH_FFI, default value = 1

* replace casting ActivationEnum by taking its value

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent dcc50c8e
...@@ -11,6 +11,8 @@ from glob import glob ...@@ -11,6 +11,8 @@ from glob import glob
from .utils import cuda_path, all_files_in_dir from .utils import cuda_path, all_files_in_dir
from typing import List from typing import List
from jax.extend import ffi
def setup_jax_extension( def setup_jax_extension(
csrc_source_files, csrc_source_files,
...@@ -27,12 +29,14 @@ def setup_jax_extension( ...@@ -27,12 +29,14 @@ def setup_jax_extension(
# Header files # Header files
cuda_home, _ = cuda_path() cuda_home, _ = cuda_path()
jax_ffi_include = ffi.include_dir()
include_dirs = [ include_dirs = [
cuda_home / "include", cuda_home / "include",
common_header_files, common_header_files,
common_header_files / "common", common_header_files / "common",
common_header_files / "common" / "include", common_header_files / "common" / "include",
csrc_header_files, csrc_header_files,
jax_ffi_include,
] ]
# Compile flags # Compile flags
......
...@@ -432,7 +432,7 @@ class TestActivationLu: ...@@ -432,7 +432,7 @@ class TestActivationLu:
def primitive_func(self, inputs): def primitive_func(self, inputs):
return jnp.mean(activation_lu(inputs, activation_type=self.activation_type)) return jnp.mean(activation_lu(inputs, activation_type=self.activation_type))
@pytest.mark.parametrize("shape", [(32, 1, 64), (64, 1, 256)]) @pytest.mark.parametrize("shape", [(32, 1, 64), (16, 64, 1, 256)])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"activation_type", "activation_type",
[ [
...@@ -450,7 +450,7 @@ class TestActivationLu: ...@@ -450,7 +450,7 @@ class TestActivationLu:
) )
def test_activation_lu(self, random_inputs, activation_type): def test_activation_lu(self, random_inputs, activation_type):
x = random_inputs x = random_inputs
x = jnp.repeat(x, len(activation_type), axis=1) x = jnp.repeat(x, len(activation_type), axis=-2)
self.activation_type = activation_type self.activation_type = activation_type
value_n_grad_primitive_func = jit(value_and_grad(self.primitive_func, (0,))) value_n_grad_primitive_func = jit(value_and_grad(self.primitive_func, (0,)))
...@@ -511,7 +511,7 @@ class TestActivationLuFP8(TestActivationLu): ...@@ -511,7 +511,7 @@ class TestActivationLuFP8(TestActivationLu):
_prim_func.defvjp(_prim_func_fwd, _prim_func_bwd) _prim_func.defvjp(_prim_func_fwd, _prim_func_bwd)
dx_trans_no_use = jnp.empty([x.shape[i] for i in self.transpose_indices], dtype=x.dtype) dx_trans_no_use = jnp.empty([x.shape[i] for i in self.transpose_axes], dtype=x.dtype)
dbias_no_use = jnp.empty(x.shape[-1], dtype=x.dtype) dbias_no_use = jnp.empty(x.shape[-1], dtype=x.dtype)
amax_no_use = jnp.zeros(1, jnp.float32) amax_no_use = jnp.zeros(1, jnp.float32)
value_n_grad_primitive_func = value_and_grad( value_n_grad_primitive_func = value_and_grad(
...@@ -520,7 +520,7 @@ class TestActivationLuFP8(TestActivationLu): ...@@ -520,7 +520,7 @@ class TestActivationLuFP8(TestActivationLu):
return value_n_grad_primitive_func(x, dx_trans_no_use, dbias_no_use, amax_no_use) return value_n_grad_primitive_func(x, dx_trans_no_use, dbias_no_use, amax_no_use)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("shape", [(32, 1, 64), (64, 1, 256)]) @pytest.mark.parametrize("shape", [(32, 1, 64), (16, 64, 1, 256)])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"activation_type", "activation_type",
[ [
...@@ -541,10 +541,12 @@ class TestActivationLuFP8(TestActivationLu): ...@@ -541,10 +541,12 @@ class TestActivationLuFP8(TestActivationLu):
self.scale = jnp.ones(1, jnp.float32) self.scale = jnp.ones(1, jnp.float32)
self.scale_inv = jnp.ones(1, jnp.float32) self.scale_inv = jnp.ones(1, jnp.float32)
self.activation_type = activation_type self.activation_type = activation_type
self.transpose_indices = (1, 2, 0)
x = random_inputs x = random_inputs
x = jnp.repeat(x, len(activation_type), axis=1) x = jnp.repeat(x, len(activation_type), axis=-2)
axes = jnp.arange(x.ndim)
self.transpose_axes = tuple([*axes[-2:]] + [*axes[:-2]])
print(self.transpose_axes)
prim_out, (prim_grad, prim_grad_trans, dbias, amax) = self.prim_func(x) prim_out, (prim_grad, prim_grad_trans, dbias, amax) = self.prim_func(x)
ref_out, (ref_grad,) = self.ref_func(x, activation_type) ref_out, (ref_grad,) = self.ref_func(x, activation_type)
...@@ -556,7 +558,7 @@ class TestActivationLuFP8(TestActivationLu): ...@@ -556,7 +558,7 @@ class TestActivationLuFP8(TestActivationLu):
assert_allclose(prim_grad, ref_grad, dtype=FP8Helper.BWD_DTYPE) assert_allclose(prim_grad, ref_grad, dtype=FP8Helper.BWD_DTYPE)
assert_allclose( assert_allclose(
prim_grad_trans, prim_grad_trans,
jnp.transpose(ref_grad, self.transpose_indices), jnp.transpose(ref_grad, self.transpose_axes),
dtype=FP8Helper.BWD_DTYPE, dtype=FP8Helper.BWD_DTYPE,
) )
......
...@@ -11,6 +11,7 @@ import jax.numpy as jnp ...@@ -11,6 +11,7 @@ import jax.numpy as jnp
from jax import core, dtypes from jax import core, dtypes
from jax.interpreters.mlir import ir from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding from jax.sharding import PartitionSpec, NamedSharding
from jax.extend import ffi
from transformer_engine import transformer_engine_jax from transformer_engine import transformer_engine_jax
from transformer_engine.transformer_engine_jax import NVTE_Activation_Type from transformer_engine.transformer_engine_jax import NVTE_Activation_Type
...@@ -22,6 +23,7 @@ from .misc import ( ...@@ -22,6 +23,7 @@ from .misc import (
jax_dtype_to_te_dtype, jax_dtype_to_te_dtype,
jax_dtype_to_ir_dtype, jax_dtype_to_ir_dtype,
get_padded_spec, get_padded_spec,
is_ffi_enabled,
) )
from .quantization import _jax_cast_fp8 from .quantization import _jax_cast_fp8
from ..sharding import all_reduce_max_along_all_axes_except_PP from ..sharding import all_reduce_max_along_all_axes_except_PP
...@@ -109,6 +111,10 @@ class ActLuPrimitive(BasePrimitive): ...@@ -109,6 +111,10 @@ class ActLuPrimitive(BasePrimitive):
""" """
(x_aval,) = ctx.avals_in (x_aval,) = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
if is_ffi_enabled():
name = "te_act_lu_ffi"
out = ffi.ffi_lowering(name)(ctx, x, act_enum=act_enum)
else:
ir_x_type = ir.RankedTensorType(x.type) ir_x_type = ir.RankedTensorType(x.type)
ir_x_shape = ir_x_type.shape ir_x_shape = ir_x_type.shape
out_shape = ir_x_shape[:-2] + [ir_x_shape[-1]] out_shape = ir_x_shape[:-2] + [ir_x_shape[-1]]
...@@ -189,7 +195,7 @@ def act_lu(inputs: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]) ...@@ -189,7 +195,7 @@ def act_lu(inputs: jnp.ndarray, activation_type: Sequence[Union[str, Callable]])
if not ActLuPrimitive.enabled(): if not ActLuPrimitive.enabled():
return _jax_act_lu(inputs, activation_type) return _jax_act_lu(inputs, activation_type)
act_type_id = ActivationEnum[activation_type] act_type_id = ActivationEnum[activation_type].value
return ActLuPrimitive.outer_primitive.bind(inputs, act_enum=act_type_id) return ActLuPrimitive.outer_primitive.bind(inputs, act_enum=act_type_id)
...@@ -231,6 +237,10 @@ class DActLuPrimitive(BasePrimitive): ...@@ -231,6 +237,10 @@ class DActLuPrimitive(BasePrimitive):
in_aval, gi_aval = ctx.avals_in in_aval, gi_aval = ctx.avals_in
assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert gi_aval.dtype == in_aval.dtype assert gi_aval.dtype == in_aval.dtype
if is_ffi_enabled():
name = "te_dact_lu_ffi"
out = ffi.ffi_lowering(name)(ctx, dz, x, act_enum=act_enum)
else:
ir_in_type = ir.RankedTensorType(dz.type) ir_in_type = ir.RankedTensorType(dz.type)
ir_in_shape = ir_in_type.shape ir_in_shape = ir_in_type.shape
gi_type = ir.RankedTensorType(x.type) gi_type = ir.RankedTensorType(x.type)
...@@ -320,12 +330,11 @@ def dact_lu( ...@@ -320,12 +330,11 @@ def dact_lu(
dact_lu fusion wrapper dact_lu fusion wrapper
Return dgated_act_lu(inputs) Return dgated_act_lu(inputs)
""" """
if not DActLuPrimitive.enabled(): if not DActLuPrimitive.enabled():
_, vjp_func = jax.vjp(partial(_jax_act_lu, activation_type=activation_type), act_lu_inputs) _, vjp_func = jax.vjp(partial(_jax_act_lu, activation_type=activation_type), act_lu_inputs)
return vjp_func(inputs)[0] return vjp_func(inputs)[0]
act_type_id = ActivationEnum[activation_type] act_type_id = ActivationEnum[activation_type].value
return DActLuPrimitive.outer_primitive.bind(inputs, act_lu_inputs, act_enum=act_type_id) return DActLuPrimitive.outer_primitive.bind(inputs, act_lu_inputs, act_enum=act_type_id)
...@@ -487,7 +496,7 @@ def act_lu_fp8( ...@@ -487,7 +496,7 @@ def act_lu_fp8(
casted_output, updated_amax = _jax_cast_fp8(act_lu_output, scale, amax, out_dtype) casted_output, updated_amax = _jax_cast_fp8(act_lu_output, scale, amax, out_dtype)
return casted_output, updated_amax return casted_output, updated_amax
act_type_id = ActivationEnum[activation_type] act_type_id = ActivationEnum[activation_type].value
return ActLuFp8Primitive.outer_primitive.bind( return ActLuFp8Primitive.outer_primitive.bind(
x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_type_id x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_type_id
) )
...@@ -3,12 +3,14 @@ ...@@ -3,12 +3,14 @@
# See LICENSE for license information. # See LICENSE for license information.
"""JAX/TE custom call""" """JAX/TE custom call"""
from dataclasses import dataclass from dataclasses import dataclass
from enum import IntEnum
from jax.lib import xla_client from jax.lib import xla_client
from jax.interpreters import mlir from jax.interpreters import mlir
from transformer_engine import transformer_engine_jax from transformer_engine import transformer_engine_jax
from .misc import is_ffi_enabled
try: try:
from jaxlib.hlo_helpers import custom_call from jaxlib.hlo_helpers import custom_call
...@@ -17,8 +19,25 @@ except ImportError: ...@@ -17,8 +19,25 @@ except ImportError:
# version, so we still need this import. # version, so we still need this import.
pass pass
class CustomCallAPIVersion(IntEnum):
    """Enum for selecting between old and new custom call registration API"""

    # Legacy custom calls: arguments are passed through an opaque descriptor string.
    OPAQUE = 0
    # New typed FFI custom calls (registered with api_version=1).
    FFI = 1
for _name, _value in transformer_engine_jax.registrations().items(): for _name, _value in transformer_engine_jax.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="CUDA") if _name.endswith("_ffi"):
if is_ffi_enabled():
# COMMAND_BUFFER_COMPATIBLE i.e. cudaGraph enabled by default
xla_client.register_custom_call_target(
_name, _value, platform="CUDA", api_version=CustomCallAPIVersion.FFI.value
)
else:
xla_client.register_custom_call_target(
_name, _value, platform="CUDA", api_version=CustomCallAPIVersion.OPAQUE.value
)
@dataclass @dataclass
...@@ -79,7 +98,7 @@ def custom_caller(name, args, opaque, has_side_effect, **kwargs): ...@@ -79,7 +98,7 @@ def custom_caller(name, args, opaque, has_side_effect, **kwargs):
result_layouts=args.output_layouts, result_layouts=args.output_layouts,
backend_config=opaque, backend_config=opaque,
has_side_effect=has_side_effect, has_side_effect=has_side_effect,
**kwargs **kwargs,
).results ).results
else: else:
# Need to disable one pylint error as the second function # Need to disable one pylint error as the second function
...@@ -93,6 +112,6 @@ def custom_caller(name, args, opaque, has_side_effect, **kwargs): ...@@ -93,6 +112,6 @@ def custom_caller(name, args, opaque, has_side_effect, **kwargs):
result_layouts=args.output_layouts, result_layouts=args.output_layouts,
backend_config=opaque, backend_config=opaque,
has_side_effect=has_side_effect, has_side_effect=has_side_effect,
**kwargs **kwargs,
) )
return out return out
...@@ -3,10 +3,14 @@ ...@@ -3,10 +3,14 @@
# See LICENSE for license information. # See LICENSE for license information.
"""JAX/TE miscellaneous for custom ops""" """JAX/TE miscellaneous for custom ops"""
import os
import functools import functools
from typing import Tuple from typing import Tuple
from importlib.metadata import version as get_pkg_version
from packaging.version import Version as PkgVersion
import numpy as np import numpy as np
import jax.numpy as jnp import jax.numpy as jnp
from jax import dtypes from jax import dtypes
from jax.interpreters.mlir import dtype_to_ir_type from jax.interpreters.mlir import dtype_to_ir_type
...@@ -142,3 +146,24 @@ def get_cudnn_version() -> Tuple[int, int, int]: ...@@ -142,3 +146,24 @@ def get_cudnn_version() -> Tuple[int, int, int]:
major, encoded_version = divmod(encoded_version, major_version_magnitude) major, encoded_version = divmod(encoded_version, major_version_magnitude)
minor, patch = divmod(encoded_version, 100) minor, patch = divmod(encoded_version, 100)
return (major, minor, patch) return (major, minor, patch)
@functools.lru_cache(maxsize=None)
def jax_version_meet_requirement(version: str):
    """
    Helper function checking if required JAX version is available

    Compares the installed jax distribution version against `version`
    (a PEP 440 version string) and returns True when installed >= required.
    Result is cached since the installed version cannot change at runtime.
    """
    installed = PkgVersion(get_pkg_version("jax"))
    required = PkgVersion(version)
    return installed >= required
def is_ffi_enabled():
    """
    Helper function checking if XLA Custom Call with FFI is enabled.

    FFI custom calls require jax >= 0.4.31. They are enabled by default and can
    be disabled by setting the environment variable NVTE_JAX_WITH_FFI=0.

    Returns:
        bool: True when the installed JAX supports FFI and the user has not
        disabled it via NVTE_JAX_WITH_FFI.
    """
    is_supported = jax_version_meet_requirement("0.4.31")
    # New APIs with FFI are enabled by default; only "0" and "1" are valid settings.
    is_enabled = int(os.getenv("NVTE_JAX_WITH_FFI", "1"))
    assert is_enabled in (0, 1), "Invalid NVTE_JAX_WITH_FFI value"
    # Cast to bool so callers always receive a boolean rather than 0/1.
    return is_supported and bool(is_enabled)
...@@ -11,6 +11,7 @@ import jax.numpy as jnp ...@@ -11,6 +11,7 @@ import jax.numpy as jnp
from jax import dtypes from jax import dtypes
from jax.interpreters.mlir import ir from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding from jax.sharding import PartitionSpec, NamedSharding
from jax.extend import ffi
from transformer_engine import transformer_engine_jax from transformer_engine import transformer_engine_jax
from transformer_engine.transformer_engine_jax import DType as TEDType from transformer_engine.transformer_engine_jax import DType as TEDType
...@@ -25,6 +26,7 @@ from .misc import ( ...@@ -25,6 +26,7 @@ from .misc import (
get_padded_spec, get_padded_spec,
multidim_transpose, multidim_transpose,
normalize_axis_boundary, normalize_axis_boundary,
is_ffi_enabled,
) )
from .activation import ActivationEnum from .activation import ActivationEnum
from .activation import _jax_act_lu from .activation import _jax_act_lu
...@@ -262,6 +264,12 @@ class CastTransposePrimitive(BasePrimitive): ...@@ -262,6 +264,12 @@ class CastTransposePrimitive(BasePrimitive):
assert amax_aval.dtype == jnp.float32 assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32 assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32 assert scale_inv_aval.dtype == jnp.float32
if is_ffi_enabled():
name = "te_cast_transpose_ffi"
out = ffi.ffi_lowering(name, operand_output_aliases={1: 2})(
ctx, x, amax, scale, scale_inv, transpose_axis=transpose_axis_boundary
)
else:
ir_x_type = ir.RankedTensorType(x.type) ir_x_type = ir.RankedTensorType(x.type)
ir_x_shape = ir_x_type.shape ir_x_shape = ir_x_type.shape
if static_axis_boundary >= 0: if static_axis_boundary >= 0:
...@@ -296,11 +304,9 @@ class CastTransposePrimitive(BasePrimitive): ...@@ -296,11 +304,9 @@ class CastTransposePrimitive(BasePrimitive):
jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(out_dtype), jax_dtype_to_te_dtype(out_dtype),
) )
out = custom_caller( out = custom_caller(
CastTransposePrimitive.name, args, opaque, False, operand_output_aliases={1: 2} CastTransposePrimitive.name, args, opaque, False, operand_output_aliases={1: 2}
) )
return out return out
@staticmethod @staticmethod
......
...@@ -13,8 +13,6 @@ ...@@ -13,8 +13,6 @@
#include <cudnn.h> #include <cudnn.h>
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <pybind11/stl.h> #include <pybind11/stl.h>
#include <transformer_engine/activation.h>
#include <transformer_engine/fused_attn.h>
#include <transformer_engine/transformer_engine.h> #include <transformer_engine/transformer_engine.h>
#include <cassert> #include <cassert>
...@@ -27,23 +25,14 @@ ...@@ -27,23 +25,14 @@
#include "common/common.h" #include "common/common.h"
#include "common/util/logging.h" #include "common/util/logging.h"
#include "extensions/ffi.h"
#include "extensions/misc.h"
#include "transformer_engine/activation.h"
#include "utils.h" #include "utils.h"
namespace transformer_engine { namespace transformer_engine {
namespace jax { namespace jax {
constexpr int kMaxNumDim = 8;
// TODO: Rename Shape to ???
struct Shape {
int num_dim;
size_t dims[kMaxNumDim];
void from_vector(const std::vector<size_t> &shape);
std::vector<size_t> to_vector() const;
};
// Phuong: These 3 functions need to stay in the header file for compilation purpose // Phuong: These 3 functions need to stay in the header file for compilation purpose
// 1. // 1.
inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; } inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; }
...@@ -62,8 +51,6 @@ const T *UnpackOpaque(const char *opaque, size_t opaque_len) { ...@@ -62,8 +51,6 @@ const T *UnpackOpaque(const char *opaque, size_t opaque_len) {
return reinterpret_cast<const T *>(opaque); return reinterpret_cast<const T *>(opaque);
} }
std::vector<size_t> MakeShapeVector(NVTEShape shape);
// Packing // Packing
struct CustomCallCommonDescriptor { struct CustomCallCommonDescriptor {
...@@ -167,6 +154,8 @@ void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size ...@@ -167,6 +154,8 @@ void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size
pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size, pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype); DType in_dtype, DType out_dtype);
XLA_FFI_DECLARE_HANDLER_SYMBOL(CastTransposeHandler);
void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
// Activation // Activation
...@@ -179,6 +168,10 @@ void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t op ...@@ -179,6 +168,10 @@ void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t op
void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuHandler);
pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size, pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype); DType in_dtype, DType out_dtype);
......
...@@ -3,15 +3,16 @@ ...@@ -3,15 +3,16 @@
* *
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include "transformer_engine/activation.h" #include "transformer_engine/activation.h"
#include "extensions.h" #include "extensions.h"
#include "transformer_engine/transpose.h" #include "transformer_engine/transpose.h"
#include "xla/ffi/api/c_api.h"
namespace transformer_engine { namespace transformer_engine {
namespace jax { namespace jax {
// TODO: We won't need this function anymore when we move to the new XLA custom calls
size_t get_activation_len(NVTE_Activation_Type activation_enum) { size_t get_activation_len(NVTE_Activation_Type activation_enum) {
switch (activation_enum) { switch (activation_enum) {
case NVTE_Activation_Type::GELU: case NVTE_Activation_Type::GELU:
...@@ -43,8 +44,7 @@ size_t get_activation_len(NVTE_Activation_Type activation_enum) { ...@@ -43,8 +44,7 @@ size_t get_activation_len(NVTE_Activation_Type activation_enum) {
void ActLuImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale, void ActLuImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale,
cudaStream_t stream, float *scale_inverse, float *amax, void *output, cudaStream_t stream, float *scale_inverse, float *amax, void *output,
NVTE_Activation_Type act_enum) { NVTE_Activation_Type act_enum, size_t act_len) {
auto act_len = get_activation_len(act_enum);
auto input_shape = std::vector<size_t>{m, n * act_len}; auto input_shape = std::vector<size_t>{m, n * act_len};
auto output_shape = std::vector<size_t>{m, n}; auto output_shape = std::vector<size_t>{m, n};
auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype)); auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));
...@@ -95,12 +95,39 @@ void ActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaqu ...@@ -95,12 +95,39 @@ void ActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaqu
auto m = desc.shape.dims[0]; auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1]; auto n = desc.shape.dims[1];
auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum); auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
; auto act_len = get_activation_len(act_enum);
ActLuImpl(input, m, n, desc.in_dtype, desc.out_dtype, nullptr, stream, nullptr, nullptr, output, ActLuImpl(input, m, n, desc.in_dtype, desc.out_dtype, nullptr, stream, nullptr, nullptr, output,
act_enum); act_enum, act_len);
} }
// FFI entry point for the forward activation custom call.
// Expects the input laid out as (..., act_len, hidden): all leading dims are
// flattened into the row count m, dims[-2] is the activation multiplicity
// (1 for plain activations, 2 for gated variants), dims[-1] is the hidden size.
Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type output_buf,
                    int64_t act_enum) {
  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());

  auto *input = input_buf.untyped_data();
  auto *output = output_buf->untyped_data();

  auto input_dims = input_buf.dimensions();
  // Use a size_t accumulator: with an int literal `1` the product is computed
  // in int and can overflow for tensors with more than 2^31 leading elements.
  auto m = std::accumulate(input_dims.begin(), input_dims.end() - 2, size_t{1},
                           std::multiplies<>());
  auto n = static_cast<size_t>(input_dims.back());
  auto act_len = static_cast<size_t>(input_dims.end()[-2]);
  auto act_type = static_cast<NVTE_Activation_Type>(act_enum);

  ActLuImpl(input, m, n, in_dtype, out_dtype, nullptr, stream, nullptr, nullptr, output, act_type,
            act_len);

  return ffi_with_cuda_error_check();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Ret<Buffer_Type>()      // output
                                  .Attr<int64_t>("act_enum"));
void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0]; auto *input = buffers[0];
float *amax = reinterpret_cast<float *>(buffers[1]); float *amax = reinterpret_cast<float *>(buffers[1]);
...@@ -119,10 +146,10 @@ void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t op ...@@ -119,10 +146,10 @@ void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t op
auto m = desc.shape.dims[0]; auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1]; auto n = desc.shape.dims[1];
auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum); auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
; auto act_len = get_activation_len(act_enum);
ActLuImpl(input, m, n, desc.in_dtype, desc.out_dtype, scale, stream, scale_inv, amax_out, output, ActLuImpl(input, m, n, desc.in_dtype, desc.out_dtype, scale, stream, scale_inv, amax_out, output,
act_enum); act_enum, act_len);
} }
void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
...@@ -134,7 +161,6 @@ void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaq ...@@ -134,7 +161,6 @@ void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaq
auto m = desc.shape.dims[0]; auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1]; auto n = desc.shape.dims[1];
auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum); auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
;
auto act_len = get_activation_len(act_enum); auto act_len = get_activation_len(act_enum);
auto input_shape = std::vector<size_t>{m, n}; auto input_shape = std::vector<size_t>{m, n};
...@@ -182,6 +208,76 @@ void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaq ...@@ -182,6 +208,76 @@ void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaq
} }
} }
// FFI entry point for the backward (gradient) activation custom call.
// `input_buf` holds the incoming gradient of shape (..., 1, hidden);
// `act_input_buf` holds the forward activation input of shape
// (..., act_len, hidden). All leading dims are flattened into m.
Error_Type DActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type act_input_buf,
                     Result_Type output_buf, int64_t act_enum) {
  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());

  auto *input = input_buf.untyped_data();
  auto *act_input = act_input_buf.untyped_data();
  auto *output = output_buf->untyped_data();

  auto act_input_dims = act_input_buf.dimensions();
  // Use a size_t accumulator: with an int literal `1` the product is computed
  // in int and can overflow for tensors with more than 2^31 leading elements.
  auto m = std::accumulate(act_input_dims.begin(), act_input_dims.end() - 2, size_t{1},
                           std::multiplies<>());
  auto n = static_cast<size_t>(act_input_dims.back());
  auto act_len = static_cast<size_t>(act_input_dims.end()[-2]);

  auto input_shape = std::vector<size_t>{m, n};
  auto act_input_shape = std::vector<size_t>{m, n * act_len};
  auto output_shape = std::vector<size_t>{m, n * act_len};
  auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));
  auto act_input_tensor = TensorWrapper(act_input, act_input_shape, static_cast<DType>(in_dtype));
  auto output_tensor = TensorWrapper(output, output_shape, static_cast<DType>(out_dtype));

  auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
  // Dispatch to the matching nvte_d* backward activation kernel.
  switch (act_type) {
    case NVTE_Activation_Type::GELU:
      nvte_dgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::GEGLU:
      nvte_dgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SILU:
      nvte_dsilu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SWIGLU:
      nvte_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::RELU:
      nvte_drelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::REGLU:
      nvte_dreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGELU:
      nvte_dqgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGEGLU:
      nvte_dqgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SRELU:
      nvte_dsrelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SREGLU:
      nvte_dsreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }
  return ffi_with_cuda_error_check();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuHandler, DActLuFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // act_input
                                  .Ret<Buffer_Type>()      // output
                                  .Attr<int64_t>("act_enum"));
pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size, pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype) { DType in_dtype, DType out_dtype) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size}; auto input_shape = std::vector<size_t>{batch_size, hidden_size};
......
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions/ffi.h"
#include <iostream>
#include "common/util/logging.h"
namespace transformer_engine {
namespace jax {
// Maps an XLA FFI dtype onto the matching TE DType.
// For XLA_FFI_DataType Enum Reference: https://github.com/openxla/xla/blob/d054e8366c4e8807726961feeb28b1cdba681888/xla/ffi/api/c_api.h#L163-L186
DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type) {
  // Each case returns directly, so the former per-case `break;` statements
  // were unreachable dead code and have been removed.
  switch (type) {
    case xla::ffi::DataType::F16:
      return DType::kFloat16;
    case xla::ffi::DataType::F32:
      return DType::kFloat32;
    case xla::ffi::DataType::BF16:
      return DType::kBFloat16;
    case xla::ffi::DataType::F8E5M2:
      return DType::kFloat8E5M2;
    case xla::ffi::DataType::F8E4M3FN:
      return DType::kFloat8E4M3;
    default:
      // NOTE(review): NVTE_ERROR appears to take a message string here — confirm
      // that the "%d" placeholder is actually formatted, or switch to
      // std::to_string(static_cast<int>(type_num)) concatenation.
      auto type_num = static_cast<XLA_FFI_DataType>(type);
      NVTE_ERROR("TE does not support conversion of XLA_FFI_DataType %d",
                 static_cast<int>(type_num));
      break;
  }
}
// Returns Success when no CUDA error is pending on the calling thread;
// otherwise wraps the last CUDA error message in an INTERNAL FFI error.
Error_Type ffi_with_cuda_error_check() {
  const cudaError_t status = cudaGetLastError();
  if (status == cudaSuccess) {
    return Error_Type::Success();
  }
  return Error_Type(XLA_FFI_Error_Code_INTERNAL,
                    std::string("CUDA error: ") + cudaGetErrorString(status));
}
} // namespace jax
} // namespace transformer_engine
/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
// Header for the XLA FFI helpers shared by TE/JAX custom calls.
// NOTE(review): added #pragma once — header had no visible include guard,
// which would cause redefinition errors on multiple inclusion.
#pragma once

#include <transformer_engine/transformer_engine.h>
#include <xla/ffi/api/ffi.h>

#include <numeric>

namespace transformer_engine {
namespace jax {

// Convenience aliases for the XLA FFI types used in handler signatures.
using Buffer_Type = xla::ffi::AnyBuffer;
using Result_Type = xla::ffi::Result<xla::ffi::AnyBuffer>;
using Error_Type = xla::ffi::Error;
using FFI = xla::ffi::Ffi;
using FFI_Stream_Type = xla::ffi::PlatformStream<cudaStream_t>;

// Maps an XLA FFI dtype to the corresponding TE DType; errors on unsupported types.
DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type);

// Returns an FFI error wrapping the last pending CUDA error, or Success if none.
Error_Type ffi_with_cuda_error_check();

}  // namespace jax
}  // namespace transformer_engine
/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
// Miscellaneous shape helpers shared by the TE/JAX custom-call extensions.
// NOTE(review): added #pragma once — header had no visible include guard,
// which would cause redefinition errors on multiple inclusion.
#pragma once

#include <transformer_engine/transformer_engine.h>

#include <cassert>
#include <string>
#include <vector>

namespace transformer_engine {
namespace jax {

// Maximum tensor rank representable by Shape.
constexpr int kMaxNumDim = 8;

// Fixed-capacity, trivially-copyable shape descriptor.
struct Shape {
  int num_dim;              // number of valid entries in dims
  size_t dims[kMaxNumDim];  // dimension sizes; only the first num_dim are meaningful

  // Copies `shape` into dims; presumably asserts shape.size() <= kMaxNumDim — confirm in the .cpp.
  void from_vector(const std::vector<size_t> &shape);
  // Returns the first num_dim entries as a vector.
  std::vector<size_t> to_vector() const;
};

// Converts an NVTEShape into a std::vector<size_t>.
std::vector<size_t> MakeShapeVector(NVTEShape shape);

}  // namespace jax
}  // namespace transformer_engine
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
* *
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include "extensions.h" #include "extensions.h"
#include "transformer_engine/layer_norm.h" #include "transformer_engine/layer_norm.h"
#include "transformer_engine/rmsnorm.h" #include "transformer_engine/rmsnorm.h"
......
...@@ -14,6 +14,13 @@ pybind11::capsule EncapsulateFunction(T *fn) { ...@@ -14,6 +14,13 @@ pybind11::capsule EncapsulateFunction(T *fn) {
return pybind11::capsule(reinterpret_cast<void *>(fn), "xla._CUSTOM_CALL_TARGET"); return pybind11::capsule(reinterpret_cast<void *>(fn), "xla._CUSTOM_CALL_TARGET");
} }
// Wraps an XLA FFI handler in a pybind11 capsule so Python can register it
// as an XLA custom-call target. Compile-time-checked to be a real FFI
// handler (callable with an XLA_FFI_CallFrame*, returning XLA_FFI_Error*).
template <typename HandlerT>
pybind11::capsule EncapsulateFFI(HandlerT *handler) {
  static_assert(std::is_invocable_r_v<XLA_FFI_Error *, HandlerT, XLA_FFI_CallFrame *>,
                "Encapsulated function must be an XLA FFI handler");
  auto *erased = reinterpret_cast<void *>(handler);
  return pybind11::capsule(erased, "xla._CUSTOM_CALL_TARGET");
}
pybind11::dict Registrations() { pybind11::dict Registrations() {
pybind11::dict dict; pybind11::dict dict;
dict["te_transpose"] = EncapsulateFunction(Transpose); dict["te_transpose"] = EncapsulateFunction(Transpose);
...@@ -44,6 +51,10 @@ pybind11::dict Registrations() { ...@@ -44,6 +51,10 @@ pybind11::dict Registrations() {
EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackward); EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackward);
dict["te_fused_attn_forward"] = EncapsulateFunction(FusedAttnForward); dict["te_fused_attn_forward"] = EncapsulateFunction(FusedAttnForward);
dict["te_fused_attn_backward"] = EncapsulateFunction(FusedAttnBackward); dict["te_fused_attn_backward"] = EncapsulateFunction(FusedAttnBackward);
dict["te_cast_transpose_ffi"] = EncapsulateFFI(CastTransposeHandler);
dict["te_act_lu_ffi"] = EncapsulateFFI(ActLuHandler);
dict["te_dact_lu_ffi"] = EncapsulateFFI(DActLuHandler);
return dict; return dict;
} }
...@@ -114,7 +125,8 @@ PYBIND11_MODULE(transformer_engine_jax, m) { ...@@ -114,7 +125,8 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
.value("QGELU", NVTE_Activation_Type::QGELU) .value("QGELU", NVTE_Activation_Type::QGELU)
.value("QGEGLU", NVTE_Activation_Type::QGEGLU) .value("QGEGLU", NVTE_Activation_Type::QGEGLU)
.value("SRELU", NVTE_Activation_Type::SRELU) .value("SRELU", NVTE_Activation_Type::SRELU)
.value("SREGLU", NVTE_Activation_Type::SREGLU); .value("SREGLU", NVTE_Activation_Type::SREGLU)
.export_values();
pybind11::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) pybind11::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend", pybind11::module_local())
.value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend) .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend)
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include "transformer_engine/transpose.h" #include "transformer_engine/transpose.h"
#include "extensions.h" #include "extensions.h"
#include "xla/ffi/api/ffi.h"
namespace transformer_engine { namespace transformer_engine {
namespace jax { namespace jax {
...@@ -66,6 +67,61 @@ void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size ...@@ -66,6 +67,61 @@ void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size
stream); stream);
} }
// FFI handler for the fused cast + transpose kernel: quantizes `input` into
// `input_cast` (dtype taken from the output buffer) and writes its transpose
// into `input_cast_trans`.
//
// `transpose_axis` splits the flattened input into a 2D [m, n] view (product
// of dims before the axis x product of dims from the axis on); negative
// values count from the end, Python-style.
//
// The XLA lowering aliases the `amax` operand to the `amax_out` result, so
// both must point at the same device memory (asserted below).
Error_Type CastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf,
                            Buffer_Type scale_buf, Buffer_Type scale_inv_buf,
                            Result_Type input_cast_buf, Result_Type input_cast_trans_buf,
                            Result_Type amax_out_buf, int64_t transpose_axis) {
  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(input_cast_buf->element_type());

  auto *input = input_buf.untyped_data();
  float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
  float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
  float *scale_inv = reinterpret_cast<float *>(scale_inv_buf.untyped_data());
  auto *input_cast = input_cast_buf->untyped_data();
  auto *input_cast_trans = input_cast_trans_buf->untyped_data();
  float *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data());
  // In-place amax update relies on the operand/result aliasing set up by the
  // Python lowering (operand_output_aliases).
  assert(amax == amax_out);
  if (!use_fp8(out_dtype)) {
    // Non-FP8 outputs carry no quantization metadata.
    scale = nullptr;
    scale_inv = nullptr;
    amax_out = nullptr;
  }

  auto input_dims = input_buf.dimensions();
  if (transpose_axis < 0) transpose_axis += input_dims.size();
  // Seed std::accumulate with size_t{1} so the dimension products are
  // computed in size_t; an int seed would evaluate them in (signed) int and
  // could overflow for large tensors before the result is widened.
  auto m = std::accumulate(input_dims.begin(), input_dims.begin() + transpose_axis, size_t{1},
                           std::multiplies<>());
  auto n = std::accumulate(input_dims.begin() + transpose_axis, input_dims.end(), size_t{1},
                           std::multiplies<>());
  auto input_shape = std::vector<size_t>{m, n};
  auto input_trans_shape = std::vector<size_t>{n, m};

  auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
  auto input_cast_tensor =
      TensorWrapper(input_cast, input_shape, out_dtype, amax_out, scale, scale_inv);
  auto input_cast_trans_tensor =
      TensorWrapper(input_cast_trans, input_trans_shape, out_dtype, amax_out, scale, scale_inv);

  nvte_cast_transpose(input_tensor.data(), input_cast_tensor.data(), input_cast_trans_tensor.data(),
                      stream);

  return ffi_with_cuda_error_check();
}
// Defines the exported FFI handler symbol `CastTransposeHandler` for
// CastTransposeFFI. The binding order below (stream ctx, then operands, then
// results, then attributes) is the handler's wire contract and must match the
// operand/result/attr order emitted by the Python lowering exactly.
XLA_FFI_DEFINE_HANDLER_SYMBOL(CastTransposeHandler, CastTransposeFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>()  // stream
.Arg<Buffer_Type>()      // input
.Arg<Buffer_Type>()      // amax
.Arg<Buffer_Type>()      // scale
.Arg<Buffer_Type>()      // scale_inv
.Ret<Buffer_Type>()      // input_cast
.Ret<Buffer_Type>()      // input_cast_trans
.Ret<Buffer_Type>()      // amax_out (aliased to amax)
.Attr<int64_t>("transpose_axis"));
pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size, pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype) { DType in_dtype, DType out_dtype) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size}; auto input_shape = std::vector<size_t>{batch_size, hidden_size};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment