Unverified Commit 26db7f34 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

[C][Jax] Move cuda kernels from Jax extensions to core (#1697)



* Move JAX CUDA kernels to core
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 04c730c0
......@@ -7,7 +7,6 @@ import os
from pathlib import Path
import setuptools
from glob import glob
from .utils import cuda_path, all_files_in_dir
from typing import List
......@@ -41,9 +40,7 @@ def setup_jax_extension(
# Source files
csrc_source_files = Path(csrc_source_files)
extensions_dir = csrc_source_files / "extensions"
sources = [
csrc_source_files / "utils.cu",
] + all_files_in_dir(extensions_dir, ".cpp")
sources = all_files_in_dir(extensions_dir, ".cpp")
# Header files
cuda_home, _ = cuda_path()
......@@ -59,13 +56,12 @@ def setup_jax_extension(
# Compile flags
cxx_flags = ["-O3"]
nvcc_flags = ["-O3"]
# Define TE/JAX as a Pybind11Extension
from pybind11.setup_helpers import Pybind11Extension
class Pybind11CUDAExtension(Pybind11Extension):
"""Modified Pybind11Extension to allow combined CXX + NVCC compile flags."""
class Pybind11CPPExtension(Pybind11Extension):
"""Modified Pybind11Extension to allow custom CXX flags."""
def _add_cflags(self, flags: List[str]) -> None:
if isinstance(self.extra_compile_args, dict):
......@@ -75,9 +71,9 @@ def setup_jax_extension(
else:
self.extra_compile_args[:0] = flags
return Pybind11CUDAExtension(
return Pybind11CPPExtension(
"transformer_engine_jax",
sources=[str(path) for path in sources],
include_dirs=[str(path) for path in include_dirs],
extra_compile_args={"cxx": cxx_flags, "nvcc": nvcc_flags},
extra_compile_args={"cxx": cxx_flags},
)
......@@ -1006,3 +1006,18 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n");
}
}
/*! C API entry point: returns the runtime number of segments found in `cu_seqlen`.
 *
 *  Thin wrapper that forwards to fused_attn::GetRuntimeNumSegments. Note that the
 *  callee synchronizes `stream` in order to return the count to the host, so this
 *  call blocks despite taking a stream argument.
 */
uint32_t nvte_get_runtime_num_segments(NVTETensor cu_seqlen, NVTETensor workspace, size_t len,
                                       cudaStream_t stream) {
  NVTE_API_CALL(nvte_get_runtime_num_segments);
  using namespace transformer_engine::fused_attn;
  return GetRuntimeNumSegments(cu_seqlen, workspace, len, stream);
}
/*! C API entry point: asynchronously writes {seed, offset} into `rng_state_dst`.
 *
 *  Thin wrapper that forwards to fused_attn::PopulateRngStateAsync; the offset is
 *  derived from the backend and the q/kv max sequence lengths. Enqueued on
 *  `stream` — the write is not complete until the stream is synchronized.
 */
void nvte_populate_rng_state_async(NVTETensor rng_state_dst, const NVTETensor seed,
                                   size_t q_max_seqlen, size_t kv_max_seqlen,
                                   NVTE_Fused_Attn_Backend backend, cudaStream_t stream) {
  NVTE_API_CALL(nvte_populate_rng_state_async);
  using namespace transformer_engine::fused_attn;
  PopulateRngStateAsync(rng_state_dst, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
}
......@@ -562,5 +562,53 @@ size_t get_max_tokens(size_t num_tokens) {
return max_t;
}
// Writes the 2-element RNG state {seed[0], offset} into rng_state_dst.
// Launched with a single thread (<<<1, 1>>> at the call site); only global
// thread 0 performs the store, so any extra threads are harmless no-ops.
__global__ void populate_rng_state_kernel(int64_t *rng_state_dst, const int64_t *const seed,
                                          int64_t offset) {
  const int gid = blockIdx.x * blockDim.x + threadIdx.x;
  if (gid == 0) {
    rng_state_dst[0] = seed[0];
    rng_state_dst[1] = offset;
  }
}
// Counts the entries of `cu_seqlen` that are strictly positive by atomically
// bumping *out once per such entry. `out` must be zero-initialized before
// launch; the launch is a 1D grid covering at least `len` threads.
__global__ void get_runtime_num_segments_kernel(int32_t *cu_seqlen, size_t len, uint32_t *out) {
  // Use a size_t flat index: blockDim.x * blockIdx.x is an unsigned-int product
  // that can wrap when len (= batch * seqlen, a size_t) exceeds 2^32 and the
  // grid grows accordingly.
  const size_t tid = static_cast<size_t>(blockDim.x) * blockIdx.x + threadIdx.x;
  if (tid >= len) return;
  if (cu_seqlen[tid] > 0) {
    // atomicAdd only supports 32-bit integer dtypes here, hence the uint32_t counter.
    atomicAdd(out, 1);
  }
}
// Asynchronously fills a 2 x int64 RNG state buffer with {seed[0], offset} on
// `stream`. The offset is taken from the thread-local FusedAttnOffsetManager and
// advanced by a backend-dependent increment: a fixed 16 for the F16
// arbitrary-seqlen backend, otherwise ceil(q_max_seqlen * kv_max_seqlen / 128).
void PopulateRngStateAsync(void *rng_state_dst, const void *seed, size_t q_max_seqlen,
                           size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
                           cudaStream_t stream) {
  constexpr int kThreadsPerCta = 128;
  const size_t increment =
      (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen)
          ? 16
          : (q_max_seqlen * kv_max_seqlen + kThreadsPerCta - 1) / kThreadsPerCta;
  const auto offset = FusedAttnOffsetManager::Instance().GetAndUpdateOffset(increment);
  populate_rng_state_kernel<<<1, 1, 0, stream>>>(reinterpret_cast<int64_t *>(rng_state_dst),
                                                 reinterpret_cast<const int64_t *>(seed), offset);
  NVTE_CHECK_CUDA(cudaGetLastError());
}
// Returns the number of strictly-positive int32 entries in the device buffer
// `cu_seqlen` (interpreted as the runtime segment count). `workspace` must be a
// device buffer of at least sizeof(uint32_t); `len` is the number of entries.
//
// Blocking: synchronizes `stream` so the count can be returned to the host.
uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream) {
  uint32_t *dout = static_cast<uint32_t *>(workspace);  // 4-byte device accumulator
  uint32_t hout{};
  // Check every CUDA call, matching the file's NVTE_CHECK_CUDA convention.
  NVTE_CHECK_CUDA(cudaMemsetAsync(dout, 0, sizeof(uint32_t), stream));
  if (len > 0) {
    constexpr int threads = 128;
    // Ceil-div. Guarded by len > 0: the previous (len - 1) / threads + 1 form
    // wraps to a huge block count when a size_t len of 0 is passed.
    const size_t blocks = (len + threads - 1) / threads;
    get_runtime_num_segments_kernel<<<blocks, threads, 0, stream>>>(
        static_cast<int32_t *>(cu_seqlen), len, dout);
    NVTE_CHECK_CUDA(cudaGetLastError());
  }
  NVTE_CHECK_CUDA(cudaMemcpyAsync(&hout, dout, sizeof(uint32_t), cudaMemcpyDeviceToHost, stream));
  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
  return hout;
}
} // namespace fused_attn
} // namespace transformer_engine
......@@ -150,6 +150,38 @@ DType get_ragged_offset_dtype(NVTE_QKV_Layout_Group layout_group, int64_t num_at
size_t get_max_batch_size(size_t batch_size);
size_t get_max_tokens(size_t num_tokens);
// Thread-local, monotonically increasing RNG offset counter for fused attention.
// Each thread gets its own singleton via Instance(); copying is disabled.
class FusedAttnOffsetManager {
 public:
  // Returns this thread's singleton instance (created on first use).
  static FusedAttnOffsetManager &Instance() {
    static thread_local FusedAttnOffsetManager instance;
    return instance;
  }

  // Returns the current offset, then advances it by `increment`
  // (post-increment semantics).
  size_t GetAndUpdateOffset(size_t increment) {
    const size_t previous = offset_;
    offset_ += increment;
    return previous;
  }

  // Non-copyable singleton.
  FusedAttnOffsetManager(FusedAttnOffsetManager const &) = delete;
  void operator=(FusedAttnOffsetManager const &) = delete;

 private:
  FusedAttnOffsetManager() {}
  size_t offset_ = 0;  // next offset to hand out
};
__global__ void populate_rng_state_kernel(int64_t *rng_state_dst, const int64_t *const seed,
int64_t offset);
__global__ void get_runtime_num_segments_kernel(int32_t *cu_seqlen, size_t len, uint32_t *out);
void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen,
size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
cudaStream_t stream);
uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream);
} // namespace fused_attn
} // namespace transformer_engine
......
......@@ -580,6 +580,31 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
int64_t window_size_right, bool deterministic, NVTETensor workspace,
cudaStream_t stream);
/*! \brief Update the RNG state with the seed and calculated offset.
*
* \param[in] rng_state_dst RNG state to store seed and offset.
* \param[in] seed Seed for RNG state.
* \param[in] q_max_seqlen Max sequence length used for computing for Q.
* it may be >= max(seqlen_q_i) for i=0,...batch_size-1.
* \param[in] kv_max_seqlen Max sequence length used for computing for K and V.
* it may be >= max(seqlen_kv_i) for i=0,...batch_size-1.
* \param[in] backend Fused attention backend.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_populate_rng_state_async(NVTETensor rng_state_dst, const NVTETensor seed,
size_t q_max_seqlen, size_t kv_max_seqlen,
NVTE_Fused_Attn_Backend backend, cudaStream_t stream);
/*! \brief Count the runtime number of segments in a cumulative sequence-length tensor.
 *
 * \param[in] cu_seqlen Cumulative sequence lengths tensor.
 * \param[in] workspace Workspace tensor (at least 4 bytes).
 * \param[in] len batch_size x sequence_length.
 * \param[in] stream CUDA stream used for this operation.
 */
uint32_t nvte_get_runtime_num_segments(NVTETensor cu_seqlen, NVTETensor workspace, size_t len,
cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
......
......@@ -28,8 +28,8 @@
#include "common/util/logging.h"
#include "extensions/ffi.h"
#include "extensions/misc.h"
#include "extensions/utils.h"
#include "transformer_engine/activation.h"
#include "utils.h"
// ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode);
......
......@@ -196,10 +196,10 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
if (cudnn_runtime_version >= 90300) { \
num_segments = input_batch * max_segments_per_seq; \
} else { \
size_t runtime_num_segments_q = \
GetRuntimeNumSegments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream); \
size_t runtime_num_segments_kv = \
GetRuntimeNumSegments(kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream); \
size_t runtime_num_segments_q = nvte_get_runtime_num_segments( \
q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream); \
size_t runtime_num_segments_kv = nvte_get_runtime_num_segments( \
kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream); \
NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv); \
NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq); \
num_segments = runtime_num_segments_q; \
......@@ -248,7 +248,7 @@ static void FusedAttnForwardImpl(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
head_dim, head_dim, window_size_left, window_size_right);
PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
/* Auxiliary tensors (to be propagated to the backward pass later) */
NVTETensorPack aux_output_tensors;
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "utils.h"
#include <cuda_runtime_api.h>
#include <cassert>
#include "common/util/cuda_runtime.h"
namespace transformer_engine {
namespace jax {
// Returns the installed CUDA runtime version as reported by
// cudaRuntimeGetVersion; errors abort via NVTE_CHECK_CUDA.
int GetCudaRuntimeVersion() {
  int ver = 0;
  NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&ver));
  return ver;
}
size_t GetCudnnRuntimeVersion() { return cudnnGetVersion(); }
int GetDeviceComputeCapability(int gpu_id) { return transformer_engine::cuda::sm_arch(gpu_id); }
} // namespace jax
} // namespace transformer_engine
......@@ -4,9 +4,6 @@
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_JAX_CSRC_UTILS_H_
#define TRANSFORMER_ENGINE_JAX_CSRC_UTILS_H_
#include <pybind11/pybind11.h>
#include <transformer_engine/fused_attn.h>
......@@ -25,12 +22,6 @@ int GetCudaRuntimeVersion();
size_t GetCudnnRuntimeVersion();
int GetDeviceComputeCapability(int gpu_id);
void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen,
size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
cudaStream_t stream);
uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream);
class cudaDevicePropertiesManager {
public:
static cudaDevicePropertiesManager &Instance() {
......@@ -63,28 +54,5 @@ class cudaDevicePropertiesManager {
cudaDeviceProp prop_;
};
// NOTE(review): removed-side copy in this diff — this class is deleted from the
// JAX header by this commit and moved into the core fused_attn header.
// Thread-local, monotonically increasing RNG offset counter for fused attention.
class FusedAttnOffsetManager {
 public:
  // Returns this thread's singleton instance (created on first use).
  static FusedAttnOffsetManager &Instance() {
    static thread_local FusedAttnOffsetManager instance;
    return instance;
  }
  // Returns the current offset, then advances it by `increment`
  // (post-increment semantics).
  size_t GetAndUpdateOffset(size_t increment) {
    size_t ret = offset_;
    offset_ += increment;
    return ret;
  }
  // Non-copyable singleton.
  FusedAttnOffsetManager(FusedAttnOffsetManager const &) = delete;
  void operator=(FusedAttnOffsetManager const &) = delete;
 private:
  FusedAttnOffsetManager() {}
  size_t offset_ = 0;  // next offset to hand out
};
} // namespace jax
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_JAX_CSRC_UTILS_H_
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_runtime_api.h>
#include <cassert>
#include "common/util/cuda_runtime.h"
#include "utils.h"
namespace transformer_engine {
namespace jax {
// NOTE(review): removed-side copy in this diff (old jax/csrc/utils.cu).
// Returns the installed CUDA runtime version via cudaRuntimeGetVersion.
int GetCudaRuntimeVersion() {
  int ver = 0;
  NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&ver));
  return ver;
}
size_t GetCudnnRuntimeVersion() { return cudnnGetVersion(); }
int GetDeviceComputeCapability(int gpu_id) { return transformer_engine::cuda::sm_arch(gpu_id); }
// NOTE(review): removed-side copy in this diff (old jax/csrc/utils.cu).
// Writes the 2-element RNG state {seed[0], offset}; launched with <<<1, 1>>>,
// so only global thread 0 does the work.
__global__ void populate_rng_state_kernel(int64_t *rng_state_dst, const int64_t *const seed,
                                          int64_t offset) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid > 0) return;
  rng_state_dst[0] = seed[0];
  rng_state_dst[1] = offset;
}
// NOTE(review): removed-side copy in this diff (old jax/csrc/utils.cu).
// Asynchronously fills a 2 x int64 RNG state buffer with {seed[0], offset} on
// `stream`; the offset comes from the thread-local FusedAttnOffsetManager.
void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen,
                           size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
                           cudaStream_t stream) {
  size_t increment = 0;
  // Increment is fixed at 16 for the F16 arbitrary-seqlen backend; otherwise
  // ceil(q_max_seqlen * kv_max_seqlen / 128).
  if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
    increment = 16;
  } else {
    constexpr int threads_per_cta = 128;
    increment = (q_max_seqlen * kv_max_seqlen + threads_per_cta - 1) / threads_per_cta;
  }
  auto offset = FusedAttnOffsetManager::Instance().GetAndUpdateOffset(increment);
  populate_rng_state_kernel<<<1, 1, 0, stream>>>(reinterpret_cast<int64_t *>(rng_state_dst),
                                                 reinterpret_cast<const int64_t *>(seed), offset);
  NVTE_CHECK_CUDA(cudaGetLastError());
}
// NOTE(review): removed-side copy in this diff (old jax/csrc/utils.cu).
// Counts strictly-positive entries of cu_seqlen by atomically bumping *out;
// *out must be zero-initialized before launch.
__global__ void get_runtime_num_segments_kernel(int32_t *cu_seqlen, size_t len, uint32_t *out) {
  // NOTE(review): int index can overflow for len > 2^31 — fixed in the moved copy.
  int tid = blockDim.x * blockIdx.x + threadIdx.x;
  if (tid >= len) return;
  if (cu_seqlen[tid] > 0) {
    // atomicAdd only supports 32-bit dtypes
    atomicAdd(out, 1);
  }
}
// NOTE(review): removed-side copy in this diff (old jax/csrc/utils.cu).
// Returns the number of strictly-positive int32 entries in `cu_seqlen`.
// Blocking: synchronizes `stream` to return the count to the host.
// NOTE(review): CUDA API results are unchecked here, and (len - 1) underflows
// for len == 0 — worth fixing in the moved copy.
uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream) {
  // workspace size requires 4 bytes
  uint32_t *dout = static_cast<uint32_t *>(workspace);
  uint32_t hout{};
  cudaMemsetAsync(dout, 0, sizeof(uint32_t), stream);
  constexpr int threads = 128;
  const int blocks = (len - 1) / threads + 1;  // ceil-div, assumes len > 0
  get_runtime_num_segments_kernel<<<blocks, threads, 0, stream>>>(static_cast<int32_t *>(cu_seqlen),
                                                                  len, dout);
  cudaMemcpyAsync(&hout, dout, sizeof(uint32_t), cudaMemcpyDeviceToHost, stream);
  cudaStreamSynchronize(stream);
  return hout;
}
} // namespace jax
} // namespace transformer_engine
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment