/*************************************************************************
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "transformer_engine/cudnn.h"

#include "../extensions.h"
#include "xla/ffi/api/c_api.h"

namespace transformer_engine {
namespace jax {

/*! \brief XLA FFI entry point that calls nvte_cudnn_handle_init().
 *
 *  None of the parameters are read by this function; they exist so the
 *  signature matches the variadic binding declared for the handler below.
 *
 *  \param[in] args   Variadic input buffers (unused).
 *  \param[in] rets   Variadic result buffers (unused).
 *  \param[in] attrs  Attribute dictionary (unused).
 *
 *  \return The value produced by ffi_with_cuda_error_check().
 *          NOTE(review): presumably this reports any pending CUDA error as an
 *          FFI error -- confirm against the helper's definition.
 */
Error_Type CudnnHandleInitFFI(Variadic_Buffer_Type args, Variadic_Result_Type rets,
                              Dictionary attrs) {
  // NOTE(review): judging by its name this sets up the cuDNN handle ahead of
  // time -- confirm against its declaration.
  nvte_cudnn_handle_init();

  // The check runs after the init call and its result is the handler status.
  Error_Type status = ffi_with_cuda_error_check();
  return status;
}

// Defines the handler symbol CudnnHandleInitHandler for CudnnHandleInitFFI.
// The binding is built from FFI::Bind<FFI_Prepare>() and forwards every
// remaining argument, every remaining result and the attribute dictionary.
// NOTE(review): FFI_Prepare presumably selects the XLA FFI "prepare"
// execution stage -- confirm where the alias is defined.
XLA_FFI_DEFINE_HANDLER_SYMBOL(CudnnHandleInitHandler, CudnnHandleInitFFI,
                              FFI::Bind<FFI_Prepare>().RemainingArgs().RemainingRets().Attrs());
}  // namespace jax
}  // namespace transformer_engine