Unverified Commit 4b2b39b4 authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[TE/JAX] Prototype for New XLA Custom Calls with FFI (#946)



* implemented custom call with ffi in csrc

* moved headers of misc to misc.h, add ffi.h

* ActLu and DActLu lowering with ffi_lowering

* CastTranspose with ffi_lowering

* enabled cudaGraph

* added 4d input test case to TestActivationLu

* added operand_output_aliases for CastTranspose

* added env var NVTE_JAX_WITH_FFI, default value = 1

* replace casting ActivationEnum by taking its value

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent dcc50c8e
......@@ -11,6 +11,8 @@ from glob import glob
from .utils import cuda_path, all_files_in_dir
from typing import List
from jax.extend import ffi
def setup_jax_extension(
csrc_source_files,
......@@ -27,12 +29,14 @@ def setup_jax_extension(
# Header files
cuda_home, _ = cuda_path()
jax_ffi_include = ffi.include_dir()
include_dirs = [
cuda_home / "include",
common_header_files,
common_header_files / "common",
common_header_files / "common" / "include",
csrc_header_files,
jax_ffi_include,
]
# Compile flags
......
......@@ -432,7 +432,7 @@ class TestActivationLu:
def primitive_func(self, inputs):
return jnp.mean(activation_lu(inputs, activation_type=self.activation_type))
@pytest.mark.parametrize("shape", [(32, 1, 64), (64, 1, 256)])
@pytest.mark.parametrize("shape", [(32, 1, 64), (16, 64, 1, 256)])
@pytest.mark.parametrize(
"activation_type",
[
......@@ -450,7 +450,7 @@ class TestActivationLu:
)
def test_activation_lu(self, random_inputs, activation_type):
x = random_inputs
x = jnp.repeat(x, len(activation_type), axis=1)
x = jnp.repeat(x, len(activation_type), axis=-2)
self.activation_type = activation_type
value_n_grad_primitive_func = jit(value_and_grad(self.primitive_func, (0,)))
......@@ -511,7 +511,7 @@ class TestActivationLuFP8(TestActivationLu):
_prim_func.defvjp(_prim_func_fwd, _prim_func_bwd)
dx_trans_no_use = jnp.empty([x.shape[i] for i in self.transpose_indices], dtype=x.dtype)
dx_trans_no_use = jnp.empty([x.shape[i] for i in self.transpose_axes], dtype=x.dtype)
dbias_no_use = jnp.empty(x.shape[-1], dtype=x.dtype)
amax_no_use = jnp.zeros(1, jnp.float32)
value_n_grad_primitive_func = value_and_grad(
......@@ -520,7 +520,7 @@ class TestActivationLuFP8(TestActivationLu):
return value_n_grad_primitive_func(x, dx_trans_no_use, dbias_no_use, amax_no_use)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("shape", [(32, 1, 64), (64, 1, 256)])
@pytest.mark.parametrize("shape", [(32, 1, 64), (16, 64, 1, 256)])
@pytest.mark.parametrize(
"activation_type",
[
......@@ -541,10 +541,12 @@ class TestActivationLuFP8(TestActivationLu):
self.scale = jnp.ones(1, jnp.float32)
self.scale_inv = jnp.ones(1, jnp.float32)
self.activation_type = activation_type
self.transpose_indices = (1, 2, 0)
x = random_inputs
x = jnp.repeat(x, len(activation_type), axis=1)
x = jnp.repeat(x, len(activation_type), axis=-2)
axes = jnp.arange(x.ndim)
self.transpose_axes = tuple([*axes[-2:]] + [*axes[:-2]])
print(self.transpose_axes)
prim_out, (prim_grad, prim_grad_trans, dbias, amax) = self.prim_func(x)
ref_out, (ref_grad,) = self.ref_func(x, activation_type)
......@@ -556,7 +558,7 @@ class TestActivationLuFP8(TestActivationLu):
assert_allclose(prim_grad, ref_grad, dtype=FP8Helper.BWD_DTYPE)
assert_allclose(
prim_grad_trans,
jnp.transpose(ref_grad, self.transpose_indices),
jnp.transpose(ref_grad, self.transpose_axes),
dtype=FP8Helper.BWD_DTYPE,
)
......
......@@ -11,6 +11,7 @@ import jax.numpy as jnp
from jax import core, dtypes
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
from jax.extend import ffi
from transformer_engine import transformer_engine_jax
from transformer_engine.transformer_engine_jax import NVTE_Activation_Type
......@@ -22,6 +23,7 @@ from .misc import (
jax_dtype_to_te_dtype,
jax_dtype_to_ir_dtype,
get_padded_spec,
is_ffi_enabled,
)
from .quantization import _jax_cast_fp8
from ..sharding import all_reduce_max_along_all_axes_except_PP
......@@ -109,6 +111,10 @@ class ActLuPrimitive(BasePrimitive):
"""
(x_aval,) = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
if is_ffi_enabled():
name = "te_act_lu_ffi"
out = ffi.ffi_lowering(name)(ctx, x, act_enum=act_enum)
else:
ir_x_type = ir.RankedTensorType(x.type)
ir_x_shape = ir_x_type.shape
out_shape = ir_x_shape[:-2] + [ir_x_shape[-1]]
......@@ -189,7 +195,7 @@ def act_lu(inputs: jnp.ndarray, activation_type: Sequence[Union[str, Callable]])
if not ActLuPrimitive.enabled():
return _jax_act_lu(inputs, activation_type)
act_type_id = ActivationEnum[activation_type]
act_type_id = ActivationEnum[activation_type].value
return ActLuPrimitive.outer_primitive.bind(inputs, act_enum=act_type_id)
......@@ -231,6 +237,10 @@ class DActLuPrimitive(BasePrimitive):
in_aval, gi_aval = ctx.avals_in
assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert gi_aval.dtype == in_aval.dtype
if is_ffi_enabled():
name = "te_dact_lu_ffi"
out = ffi.ffi_lowering(name)(ctx, dz, x, act_enum=act_enum)
else:
ir_in_type = ir.RankedTensorType(dz.type)
ir_in_shape = ir_in_type.shape
gi_type = ir.RankedTensorType(x.type)
......@@ -320,12 +330,11 @@ def dact_lu(
dact_lu fusion wrapper
Return dgated_act_lu(inputs)
"""
if not DActLuPrimitive.enabled():
_, vjp_func = jax.vjp(partial(_jax_act_lu, activation_type=activation_type), act_lu_inputs)
return vjp_func(inputs)[0]
act_type_id = ActivationEnum[activation_type]
act_type_id = ActivationEnum[activation_type].value
return DActLuPrimitive.outer_primitive.bind(inputs, act_lu_inputs, act_enum=act_type_id)
......@@ -487,7 +496,7 @@ def act_lu_fp8(
casted_output, updated_amax = _jax_cast_fp8(act_lu_output, scale, amax, out_dtype)
return casted_output, updated_amax
act_type_id = ActivationEnum[activation_type]
act_type_id = ActivationEnum[activation_type].value
return ActLuFp8Primitive.outer_primitive.bind(
x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_type_id
)
......@@ -3,12 +3,14 @@
# See LICENSE for license information.
"""JAX/TE custom call"""
from dataclasses import dataclass
from enum import IntEnum
from jax.lib import xla_client
from jax.interpreters import mlir
from transformer_engine import transformer_engine_jax
from .misc import is_ffi_enabled
try:
from jaxlib.hlo_helpers import custom_call
......@@ -17,8 +19,25 @@ except ImportError:
# version, so we still need this import.
pass
class CustomCallAPIVersion(IntEnum):
    """Enum for selecting between old and new custom call registration API"""

    # Legacy custom-call API: operands are described by an opaque byte string.
    OPAQUE = 0
    # New XLA FFI API: typed buffers/attributes; registered with api_version=1.
    FFI = 1
for _name, _value in transformer_engine_jax.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
if _name.endswith("_ffi"):
if is_ffi_enabled():
# COMMAND_BUFFER_COMPATIBLE i.e. cudaGraph enabled by default
xla_client.register_custom_call_target(
_name, _value, platform="CUDA", api_version=CustomCallAPIVersion.FFI.value
)
else:
xla_client.register_custom_call_target(
_name, _value, platform="CUDA", api_version=CustomCallAPIVersion.OPAQUE.value
)
@dataclass
......@@ -79,7 +98,7 @@ def custom_caller(name, args, opaque, has_side_effect, **kwargs):
result_layouts=args.output_layouts,
backend_config=opaque,
has_side_effect=has_side_effect,
**kwargs
**kwargs,
).results
else:
# Need to disable one pylint error as the second function
......@@ -93,6 +112,6 @@ def custom_caller(name, args, opaque, has_side_effect, **kwargs):
result_layouts=args.output_layouts,
backend_config=opaque,
has_side_effect=has_side_effect,
**kwargs
**kwargs,
)
return out
......@@ -3,10 +3,14 @@
# See LICENSE for license information.
"""JAX/TE miscellaneous for custom ops"""
import os
import functools
from typing import Tuple
from importlib.metadata import version as get_pkg_version
from packaging.version import Version as PkgVersion
import numpy as np
import jax.numpy as jnp
from jax import dtypes
from jax.interpreters.mlir import dtype_to_ir_type
......@@ -142,3 +146,24 @@ def get_cudnn_version() -> Tuple[int, int, int]:
major, encoded_version = divmod(encoded_version, major_version_magnitude)
minor, patch = divmod(encoded_version, 100)
return (major, minor, patch)
@functools.lru_cache(maxsize=None)
def jax_version_meet_requirement(version: str):
    """
    Return True when the installed JAX package is at least ``version``.

    The result is cached for the lifetime of the process: the installed
    JAX version cannot change at runtime and the package-metadata lookup
    is comparatively expensive.
    """
    return PkgVersion(get_pkg_version("jax")) >= PkgVersion(version)
def is_ffi_enabled():
    """
    Helper function checking if XLA Custom Call with FFI is enabled.

    FFI custom calls require JAX >= 0.4.31 and are enabled by default;
    set the environment variable NVTE_JAX_WITH_FFI=0 to opt out.

    Returns: bool — True when FFI custom calls should be used.
    Raises: AssertionError if NVTE_JAX_WITH_FFI is set to anything other
        than "0" or "1".
    """
    is_supported = jax_version_meet_requirement("0.4.31")
    # New APIs with FFI are enabled by default.
    # Validate the raw string before converting: a bare int() call would
    # raise an unhelpful ValueError for values such as "true".
    env_value = os.getenv("NVTE_JAX_WITH_FFI", "1")
    assert env_value in ("0", "1"), "Invalid NVTE_JAX_WITH_FFI value"
    return is_supported and env_value == "1"
......@@ -11,6 +11,7 @@ import jax.numpy as jnp
from jax import dtypes
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
from jax.extend import ffi
from transformer_engine import transformer_engine_jax
from transformer_engine.transformer_engine_jax import DType as TEDType
......@@ -25,6 +26,7 @@ from .misc import (
get_padded_spec,
multidim_transpose,
normalize_axis_boundary,
is_ffi_enabled,
)
from .activation import ActivationEnum
from .activation import _jax_act_lu
......@@ -262,6 +264,12 @@ class CastTransposePrimitive(BasePrimitive):
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
if is_ffi_enabled():
name = "te_cast_transpose_ffi"
out = ffi.ffi_lowering(name, operand_output_aliases={1: 2})(
ctx, x, amax, scale, scale_inv, transpose_axis=transpose_axis_boundary
)
else:
ir_x_type = ir.RankedTensorType(x.type)
ir_x_shape = ir_x_type.shape
if static_axis_boundary >= 0:
......@@ -296,11 +304,9 @@ class CastTransposePrimitive(BasePrimitive):
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(out_dtype),
)
out = custom_caller(
CastTransposePrimitive.name, args, opaque, False, operand_output_aliases={1: 2}
)
return out
@staticmethod
......
......@@ -13,8 +13,6 @@
#include <cudnn.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <transformer_engine/activation.h>
#include <transformer_engine/fused_attn.h>
#include <transformer_engine/transformer_engine.h>
#include <cassert>
......@@ -27,23 +25,14 @@
#include "common/common.h"
#include "common/util/logging.h"
#include "extensions/ffi.h"
#include "extensions/misc.h"
#include "transformer_engine/activation.h"
#include "utils.h"
namespace transformer_engine {
namespace jax {
constexpr int kMaxNumDim = 8;
// TODO: Rename Shape to ???
struct Shape {
int num_dim;
size_t dims[kMaxNumDim];
void from_vector(const std::vector<size_t> &shape);
std::vector<size_t> to_vector() const;
};
// Phuong: These 3 functions need to stay in the header file for compilation purpose
// 1.
inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; }
......@@ -62,8 +51,6 @@ const T *UnpackOpaque(const char *opaque, size_t opaque_len) {
return reinterpret_cast<const T *>(opaque);
}
std::vector<size_t> MakeShapeVector(NVTEShape shape);
// Packing
struct CustomCallCommonDescriptor {
......@@ -167,6 +154,8 @@ void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size
pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype);
XLA_FFI_DECLARE_HANDLER_SYMBOL(CastTransposeHandler);
void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
// Activation
......@@ -179,6 +168,10 @@ void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t op
void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuHandler);
pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype);
......
......@@ -3,15 +3,16 @@
*
* See LICENSE for license information.
************************************************************************/
#include "transformer_engine/activation.h"
#include "extensions.h"
#include "transformer_engine/transpose.h"
#include "xla/ffi/api/c_api.h"
namespace transformer_engine {
namespace jax {
// TODO: We won't need this function anymore when we move to the new XLA custom calls
size_t get_activation_len(NVTE_Activation_Type activation_enum) {
switch (activation_enum) {
case NVTE_Activation_Type::GELU:
......@@ -43,8 +44,7 @@ size_t get_activation_len(NVTE_Activation_Type activation_enum) {
void ActLuImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale,
cudaStream_t stream, float *scale_inverse, float *amax, void *output,
NVTE_Activation_Type act_enum) {
auto act_len = get_activation_len(act_enum);
NVTE_Activation_Type act_enum, size_t act_len) {
auto input_shape = std::vector<size_t>{m, n * act_len};
auto output_shape = std::vector<size_t>{m, n};
auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));
......@@ -95,12 +95,39 @@ void ActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaqu
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
;
auto act_len = get_activation_len(act_enum);
ActLuImpl(input, m, n, desc.in_dtype, desc.out_dtype, nullptr, stream, nullptr, nullptr, output,
act_enum);
act_enum, act_len);
}
// FFI entry point for the fused activation forward pass.
// Expects input layout (..., act_len, n): all leading dimensions are flattened
// into the row count m, the second-to-last dimension is the activation fan-in
// (1 for plain activations, 2 for gated ones), and the last dimension is n.
Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type output_buf,
                    int64_t act_enum) {
  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());

  auto *input = input_buf.untyped_data();
  auto *output = output_buf->untyped_data();

  auto input_dims = input_buf.dimensions();
  // Accumulate in size_t (not the literal int 1) so large leading dimensions
  // cannot overflow a 32-bit accumulator and the result matches the size_t
  // parameters of ActLuImpl without narrowing.
  auto m =
      std::accumulate(input_dims.begin(), input_dims.end() - 2, size_t{1}, std::multiplies<>());
  auto n = static_cast<size_t>(input_dims.back());
  auto act_len = static_cast<size_t>(input_dims.end()[-2]);

  auto act_type = static_cast<NVTE_Activation_Type>(act_enum);

  // Non-FP8 path: no scale / scale_inv / amax tensors.
  ActLuImpl(input, m, n, in_dtype, out_dtype, nullptr, stream, nullptr, nullptr, output, act_type,
            act_len);

  return ffi_with_cuda_error_check();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Ret<Buffer_Type>()      // output
                                  .Attr<int64_t>("act_enum"));
void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
float *amax = reinterpret_cast<float *>(buffers[1]);
......@@ -119,10 +146,10 @@ void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t op
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
;
auto act_len = get_activation_len(act_enum);
ActLuImpl(input, m, n, desc.in_dtype, desc.out_dtype, scale, stream, scale_inv, amax_out, output,
act_enum);
act_enum, act_len);
}
void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
......@@ -134,7 +161,6 @@ void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaq
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
;
auto act_len = get_activation_len(act_enum);
auto input_shape = std::vector<size_t>{m, n};
......@@ -182,6 +208,76 @@ void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaq
}
}
// FFI entry point for the fused activation backward pass.
// input_buf holds the incoming gradient (shape m x n after flattening);
// act_input_buf holds the forward-pass activation input with layout
// (..., act_len, n). The output gradient matches act_input's flattened shape.
Error_Type DActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type act_input_buf,
                     Result_Type output_buf, int64_t act_enum) {
  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());

  auto *input = input_buf.untyped_data();
  auto *act_input = act_input_buf.untyped_data();
  auto *output = output_buf->untyped_data();

  auto act_input_dims = act_input_buf.dimensions();
  // size_t accumulator: avoids 32-bit overflow on large leading dims and
  // keeps m/n unsigned so the vector<size_t> shapes below need no narrowing.
  auto m = std::accumulate(act_input_dims.begin(), act_input_dims.end() - 2, size_t{1},
                           std::multiplies<>());
  auto n = static_cast<size_t>(act_input_dims.back());
  auto act_len = static_cast<size_t>(act_input_dims.end()[-2]);

  auto input_shape = std::vector<size_t>{m, n};
  auto act_input_shape = std::vector<size_t>{m, n * act_len};
  auto output_shape = std::vector<size_t>{m, n * act_len};

  auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));
  auto act_input_tensor = TensorWrapper(act_input, act_input_shape, static_cast<DType>(in_dtype));
  auto output_tensor = TensorWrapper(output, output_shape, static_cast<DType>(out_dtype));

  // Dispatch to the matching nvte_d* backward kernel for the activation type.
  auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
  switch (act_type) {
    case NVTE_Activation_Type::GELU:
      nvte_dgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::GEGLU:
      nvte_dgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SILU:
      nvte_dsilu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SWIGLU:
      nvte_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::RELU:
      nvte_drelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::REGLU:
      nvte_dreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGELU:
      nvte_dqgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGEGLU:
      nvte_dqgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SRELU:
      nvte_dsrelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SREGLU:
      nvte_dsreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }

  return ffi_with_cuda_error_check();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuHandler, DActLuFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // act_input
                                  .Ret<Buffer_Type>()      // output
                                  .Attr<int64_t>("act_enum"));
pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size};
......
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions/ffi.h"
#include <iostream>
#include "common/util/logging.h"
namespace transformer_engine {
namespace jax {
// For XLA_FFI_DataType Enum Reference: https://github.com/openxla/xla/blob/d054e8366c4e8807726961feeb28b1cdba681888/xla/ffi/api/c_api.h#L163-L186
// Maps an XLA FFI element type to the corresponding TE DType.
// For XLA_FFI_DataType Enum Reference: https://github.com/openxla/xla/blob/d054e8366c4e8807726961feeb28b1cdba681888/xla/ffi/api/c_api.h#L163-L186
// Throws via NVTE_ERROR for any type TE does not support.
DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type) {
  switch (type) {
    case xla::ffi::DataType::F16:
      return DType::kFloat16;
    case xla::ffi::DataType::F32:
      return DType::kFloat32;
    case xla::ffi::DataType::BF16:
      return DType::kBFloat16;
    case xla::ffi::DataType::F8E5M2:
      return DType::kFloat8E5M2;
    case xla::ffi::DataType::F8E4M3FN:
      return DType::kFloat8E4M3;
    default: {
      // NVTE_ERROR is not printf-style (it concatenates its arguments), so
      // build the message explicitly instead of passing a "%d" specifier.
      auto type_num = static_cast<XLA_FFI_DataType>(type);
      NVTE_ERROR("TE does not support conversion of XLA_FFI_DataType " +
                 std::to_string(static_cast<int>(type_num)));
    }
  }
}
// Surface any pending asynchronous CUDA error as an XLA FFI error; returns
// Success when the CUDA error state is clean.
Error_Type ffi_with_cuda_error_check() {
  const cudaError_t status = cudaGetLastError();
  if (status == cudaSuccess) {
    return Error_Type::Success();
  }
  return Error_Type(XLA_FFI_Error_Code_INTERNAL,
                    std::string("CUDA error: ") + cudaGetErrorString(status));
}
} // namespace jax
} // namespace transformer_engine
/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
#pragma once

#include <cuda_runtime.h>

#include <transformer_engine/transformer_engine.h>
#include <xla/ffi/api/ffi.h>

#include <numeric>

namespace transformer_engine {
namespace jax {

// Convenience aliases for the XLA FFI binding types used by TE handlers.
using Buffer_Type = xla::ffi::AnyBuffer;
using Result_Type = xla::ffi::Result<xla::ffi::AnyBuffer>;
using Error_Type = xla::ffi::Error;
using FFI = xla::ffi::Ffi;
using FFI_Stream_Type = xla::ffi::PlatformStream<cudaStream_t>;

// Maps an XLA FFI element type to the corresponding TE DType; throws on
// unsupported types.
DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type);

// Converts any pending CUDA error into an FFI error; Success otherwise.
Error_Type ffi_with_cuda_error_check();

}  // namespace jax
}  // namespace transformer_engine
/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
#pragma once

#include <transformer_engine/transformer_engine.h>

#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

namespace transformer_engine {
namespace jax {

// Maximum tensor rank supported by the fixed-size Shape struct below.
constexpr int kMaxNumDim = 8;

// POD shape descriptor used to pass tensor dimensions through opaque
// custom-call descriptors (fixed-size so it is trivially serializable).
struct Shape {
  int num_dim;                // number of valid entries in dims
  size_t dims[kMaxNumDim];    // dimension sizes, dims[0..num_dim)

  // Copies the dimensions from `shape`; presumably asserts/truncates when
  // shape.size() > kMaxNumDim — TODO confirm in the .cpp definition.
  void from_vector(const std::vector<size_t> &shape);

  // Returns the first num_dim entries as a vector.
  std::vector<size_t> to_vector() const;
};

// Converts an NVTEShape into a std::vector of dimension sizes.
std::vector<size_t> MakeShapeVector(NVTEShape shape);

}  // namespace jax
}  // namespace transformer_engine
......@@ -3,7 +3,6 @@
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
#include "transformer_engine/layer_norm.h"
#include "transformer_engine/rmsnorm.h"
......
......@@ -14,6 +14,13 @@ pybind11::capsule EncapsulateFunction(T *fn) {
return pybind11::capsule(reinterpret_cast<void *>(fn), "xla._CUSTOM_CALL_TARGET");
}
// Wraps an XLA FFI handler in a PyCapsule so it can be handed to
// xla_client.register_custom_call_target from Python. The static_assert
// rejects anything that is not a well-formed FFI handler signature
// (XLA_FFI_Error* (XLA_FFI_CallFrame*)) at compile time.
template <typename T>
pybind11::capsule EncapsulateFFI(T *fn) {
  static_assert(std::is_invocable_r_v<XLA_FFI_Error *, T, XLA_FFI_CallFrame *>,
                "Encapsulated function must be an XLA FFI handler");
  // "xla._CUSTOM_CALL_TARGET" is the capsule name XLA expects at registration.
  return pybind11::capsule(reinterpret_cast<void *>(fn), "xla._CUSTOM_CALL_TARGET");
}
pybind11::dict Registrations() {
pybind11::dict dict;
dict["te_transpose"] = EncapsulateFunction(Transpose);
......@@ -44,6 +51,10 @@ pybind11::dict Registrations() {
EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackward);
dict["te_fused_attn_forward"] = EncapsulateFunction(FusedAttnForward);
dict["te_fused_attn_backward"] = EncapsulateFunction(FusedAttnBackward);
dict["te_cast_transpose_ffi"] = EncapsulateFFI(CastTransposeHandler);
dict["te_act_lu_ffi"] = EncapsulateFFI(ActLuHandler);
dict["te_dact_lu_ffi"] = EncapsulateFFI(DActLuHandler);
return dict;
}
......@@ -114,7 +125,8 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
.value("QGELU", NVTE_Activation_Type::QGELU)
.value("QGEGLU", NVTE_Activation_Type::QGEGLU)
.value("SRELU", NVTE_Activation_Type::SRELU)
.value("SREGLU", NVTE_Activation_Type::SREGLU);
.value("SREGLU", NVTE_Activation_Type::SREGLU)
.export_values();
pybind11::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend", pybind11::module_local())
.value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend)
......
......@@ -7,6 +7,7 @@
#include "transformer_engine/transpose.h"
#include "extensions.h"
#include "xla/ffi/api/ffi.h"
namespace transformer_engine {
namespace jax {
......@@ -66,6 +67,61 @@ void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size
stream);
}
// FFI entry point for fused FP8 cast + transpose. The input is flattened to a
// 2-D (m, n) matrix around `transpose_axis`; outputs are the cast matrix and
// its (n, m) transpose, plus the updated amax.
Error_Type CastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf,
                            Buffer_Type scale_buf, Buffer_Type scale_inv_buf,
                            Result_Type input_cast_buf, Result_Type input_cast_trans_buf,
                            Result_Type amax_out_buf, int64_t transpose_axis) {
  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(input_cast_buf->element_type());

  auto *input = input_buf.untyped_data();
  float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
  float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
  float *scale_inv = reinterpret_cast<float *>(scale_inv_buf.untyped_data());
  auto *input_cast = input_cast_buf->untyped_data();
  auto *input_cast_trans = input_cast_trans_buf->untyped_data();
  float *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data());
  // amax is donated to amax_out via operand_output_aliases={1: 2} on the
  // Python side, so both must refer to the same buffer.
  assert(amax == amax_out);
  if (!use_fp8(out_dtype)) {
    // Non-FP8 output: the quantization tensors are meaningless, drop them.
    scale = nullptr;
    scale_inv = nullptr;
    amax_out = nullptr;
  }

  auto input_dims = input_buf.dimensions();
  // Normalize a negative axis the same way Python does.
  if (transpose_axis < 0) transpose_axis += input_dims.size();
  // Accumulate in size_t so large shapes cannot overflow a 32-bit int and the
  // vector<size_t> shape initializers below involve no narrowing conversion.
  auto m = std::accumulate(input_dims.begin(), input_dims.begin() + transpose_axis, size_t{1},
                           std::multiplies<>());
  auto n = std::accumulate(input_dims.begin() + transpose_axis, input_dims.end(), size_t{1},
                           std::multiplies<>());
  auto input_shape = std::vector<size_t>{m, n};
  auto input_trans_shape = std::vector<size_t>{n, m};

  auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
  auto input_cast_tensor =
      TensorWrapper(input_cast, input_shape, out_dtype, amax_out, scale, scale_inv);
  auto input_cast_trans_tensor =
      TensorWrapper(input_cast_trans, input_trans_shape, out_dtype, amax_out, scale, scale_inv);

  nvte_cast_transpose(input_tensor.data(), input_cast_tensor.data(), input_cast_trans_tensor.data(),
                      stream);

  return ffi_with_cuda_error_check();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(CastTransposeHandler, CastTransposeFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // amax
                                  .Arg<Buffer_Type>()      // scale
                                  .Arg<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // input_cast
                                  .Ret<Buffer_Type>()      // input_cast_trans
                                  .Ret<Buffer_Type>()      // amax_out
                                  .Attr<int64_t>("transpose_axis"));
pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype) {
auto input_shape = std::vector<size_t>{batch_size, hidden_size};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment