Unverified commit 12f30ead, authored by Phuong Nguyen, committed by GitHub
Browse files

[TE/JAX] Enabling CudaGraph for custom calls with FFI (#1228)



* register CmdBufferCompatible traits via C++ API

* renamed FFI_Traits

* use register_ffi_target()

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 8e97c8da
......@@ -5,8 +5,8 @@
from dataclasses import dataclass
from enum import IntEnum
from jax.lib import xla_client
from jax.interpreters import mlir
import jax.extend as jex
from transformer_engine import transformer_engine_jax
......@@ -30,12 +30,11 @@ class CustomCallAPIVersion(IntEnum):
# Register every custom-call target exported by the TE C++ extension with XLA.
# registrations() yields both new-style names (suffixed "_ffi") and their legacy
# opaque-descriptor counterparts, so each backend implementation is registered
# exactly once per iteration.
# NOTE(review): diff extraction had interleaved the pre-change
# xla_client.register_custom_call_target lines with the added
# jex.ffi.register_ffi_target lines; this is the reconstructed post-commit code.
for _name, _value in transformer_engine_jax.registrations().items():
    if _name.endswith("_ffi"):
        # New FFI (typed) custom calls. Only registered when the FFI path is
        # enabled; otherwise the legacy counterpart (same name without the
        # "_ffi" suffix) is registered via the else branch on its own iteration.
        if is_ffi_enabled():
            # COMMAND_BUFFER_COMPATIBLE i.e. cudaGraph enabled by default
            jex.ffi.register_ffi_target(
                _name, _value, platform="CUDA", api_version=CustomCallAPIVersion.FFI.value
            )
    else:
        # Legacy opaque-descriptor custom calls (api_version 0).
        jex.ffi.register_ffi_target(
            _name, _value, platform="CUDA", api_version=CustomCallAPIVersion.OPAQUE.value
        )
......
......@@ -126,7 +126,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI,
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Ret<Buffer_Type>() // output
.Attr<int64_t>("act_enum"));
.Attr<int64_t>("act_enum"),
FFI_CudaGraph_Traits);
void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
......@@ -276,7 +277,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuHandler, DActLuFFI,
.Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // act_input
.Ret<Buffer_Type>() // output
.Attr<int64_t>("act_enum"));
.Attr<int64_t>("act_enum"),
FFI_CudaGraph_Traits);
pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype) {
......
......@@ -17,6 +17,7 @@ using Result_Type = xla::ffi::Result<xla::ffi::AnyBuffer>;
// Short aliases for the XLA FFI vocabulary types used by the TE handlers.
using Error_Type = xla::ffi::Error;
using FFI = xla::ffi::Ffi;
using FFI_Stream_Type = xla::ffi::PlatformStream<cudaStream_t>;
// Traits list passed to XLA_FFI_DEFINE_HANDLER_SYMBOL to mark a handler as
// command-buffer compatible, i.e. safe to capture into a CUDA graph.
// NOTE(review): `constexpr auto` with a braced initializer deduces
// std::initializer_list; constexpr gives it internal linkage per TU, which
// appears intentional for this header — confirm against the XLA FFI API.
constexpr auto FFI_CudaGraph_Traits = {xla::ffi::Traits::kCmdBufferCompatible};
// Maps an XLA FFI dtype to the corresponding Transformer Engine DType.
DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type);
// Converts the last CUDA error (if any) into an FFI Error_Type.
Error_Type ffi_with_cuda_error_check();
......
......@@ -120,7 +120,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(CastTransposeHandler, CastTransposeFFI,
.Ret<Buffer_Type>() // input_cast
.Ret<Buffer_Type>() // input_cast_trans
.Ret<Buffer_Type>() // amax_out
.Attr<int64_t>("transpose_axis"));
.Attr<int64_t>("transpose_axis"),
FFI_CudaGraph_Traits);
pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment