Unverified Commit 26db7f34 authored by Kirthi Shankar Sivamani, committed by GitHub

[C][Jax] Move cuda kernels from Jax extensions to core (#1697)



* Move JAX CUDA kernels to core
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 04c730c0
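
In short: this commit deletes the JAX extension's private CUDA helpers (populate_rng_state_kernel, get_runtime_num_segments_kernel, their host wrappers, and the FusedAttnOffsetManager singleton) and re-homes them in the core library behind two new public C API entry points, so the JAX extension itself builds as plain C++. A minimal before/after sketch of the call-site change, taken from the attention hunk further below:

// Before: extension-internal helper, compiled into the JAX extension by nvcc
size_t n = GetRuntimeNumSegments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream);

// After: public core C API, declared in transformer_engine/fused_attn.h
size_t n = nvte_get_runtime_num_segments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream);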
@@ -7,7 +7,6 @@ import os
 from pathlib import Path

 import setuptools

-from glob import glob
 from .utils import cuda_path, all_files_in_dir
 from typing import List
@@ -41,9 +40,7 @@ def setup_jax_extension(
     # Source files
     csrc_source_files = Path(csrc_source_files)
     extensions_dir = csrc_source_files / "extensions"
-    sources = [
-        csrc_source_files / "utils.cu",
-    ] + all_files_in_dir(extensions_dir, ".cpp")
+    sources = all_files_in_dir(extensions_dir, ".cpp")

     # Header files
     cuda_home, _ = cuda_path()
@@ -59,13 +56,12 @@ def setup_jax_extension(
     # Compile flags
     cxx_flags = ["-O3"]
-    nvcc_flags = ["-O3"]

     # Define TE/JAX as a Pybind11Extension
     from pybind11.setup_helpers import Pybind11Extension

-    class Pybind11CUDAExtension(Pybind11Extension):
-        """Modified Pybind11Extension to allow combined CXX + NVCC compile flags."""
+    class Pybind11CPPExtension(Pybind11Extension):
+        """Modified Pybind11Extension to allow custom CXX flags."""

         def _add_cflags(self, flags: List[str]) -> None:
             if isinstance(self.extra_compile_args, dict):
@@ -75,9 +71,9 @@ def setup_jax_extension(
             else:
                 self.extra_compile_args[:0] = flags

-    return Pybind11CUDAExtension(
+    return Pybind11CPPExtension(
         "transformer_engine_jax",
         sources=[str(path) for path in sources],
         include_dirs=[str(path) for path in include_dirs],
-        extra_compile_args={"cxx": cxx_flags, "nvcc": nvcc_flags},
+        extra_compile_args={"cxx": cxx_flags},
     )
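
With utils.cu gone from the extension's source list, everything the JAX extension compiles is ordinary C++, so the nvcc flag set and the combined CXX + NVCC Pybind11CUDAExtension are dead weight; this hunk removes both and renames the helper class accordingly.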
@@ -1006,3 +1006,18 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
     NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n");
   }
 }
+
+uint32_t nvte_get_runtime_num_segments(NVTETensor cu_seqlen, NVTETensor workspace, size_t len,
+                                       cudaStream_t stream) {
+  NVTE_API_CALL(nvte_get_runtime_num_segments);
+  using namespace transformer_engine::fused_attn;
+  return GetRuntimeNumSegments(cu_seqlen, workspace, len, stream);
+}
+
+void nvte_populate_rng_state_async(NVTETensor rng_state_dst, const NVTETensor seed,
+                                   size_t q_max_seqlen, size_t kv_max_seqlen,
+                                   NVTE_Fused_Attn_Backend backend, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_populate_rng_state_async);
+  using namespace transformer_engine::fused_attn;
+  PopulateRngStateAsync(rng_state_dst, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
+}
@@ -562,5 +562,53 @@ size_t get_max_tokens(size_t num_tokens) {
   return max_t;
 }
+
+__global__ void populate_rng_state_kernel(int64_t *rng_state_dst, const int64_t *const seed,
+                                          int64_t offset) {
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  if (tid > 0) return;
+  rng_state_dst[0] = seed[0];
+  rng_state_dst[1] = offset;
+}
+
+__global__ void get_runtime_num_segments_kernel(int32_t *cu_seqlen, size_t len, uint32_t *out) {
+  int tid = blockDim.x * blockIdx.x + threadIdx.x;
+  if (tid >= len) return;
+  if (cu_seqlen[tid] > 0) {
+    // atomicAdd only supports 32-bit dtypes
+    atomicAdd(out, 1);
+  }
+}
+
+void PopulateRngStateAsync(void *rng_state_dst, const void *seed, size_t q_max_seqlen,
+                           size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
+                           cudaStream_t stream) {
+  size_t increment = 0;
+  if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
+    increment = 16;
+  } else {
+    constexpr int threads_per_cta = 128;
+    increment = (q_max_seqlen * kv_max_seqlen + threads_per_cta - 1) / threads_per_cta;
+  }
+  auto offset = FusedAttnOffsetManager::Instance().GetAndUpdateOffset(increment);
+  populate_rng_state_kernel<<<1, 1, 0, stream>>>(reinterpret_cast<int64_t *>(rng_state_dst),
+                                                 reinterpret_cast<const int64_t *>(seed), offset);
+  NVTE_CHECK_CUDA(cudaGetLastError());
+}
+
+uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream) {
+  // The workspace only needs to hold a single 4-byte counter.
+  uint32_t *dout = static_cast<uint32_t *>(workspace);
+  uint32_t hout{};
+  cudaMemsetAsync(dout, 0, sizeof(uint32_t), stream);
+  constexpr int threads = 128;
+  const int blocks = (len - 1) / threads + 1;
+  get_runtime_num_segments_kernel<<<blocks, threads, 0, stream>>>(static_cast<int32_t *>(cu_seqlen),
+                                                                  len, dout);
+  cudaMemcpyAsync(&hout, dout, sizeof(uint32_t), cudaMemcpyDeviceToHost, stream);
+  cudaStreamSynchronize(stream);
+  return hout;
+}
+
 }  // namespace fused_attn
 }  // namespace transformer_engine
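
For intuition, here is a hypothetical standalone driver (not part of the commit) that reproduces what get_runtime_num_segments_kernel computes: the number of positive entries in a device-side cu_seqlen buffer, accumulated through a 4-byte workspace counter. The buffer contents below are made up for illustration.

#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

__global__ void count_positive(const int32_t *cu_seqlen, size_t len, uint32_t *out) {
  size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= len) return;
  if (cu_seqlen[tid] > 0) atomicAdd(out, 1);  // 32-bit atomic, same as the kernel above
}

int main() {
  const int32_t host_buf[8] = {0, 5, 9, 12, 0, 0, 0, 0};  // three positive entries
  int32_t *dev_buf;
  uint32_t *dev_out, host_out = 0;
  cudaMalloc(&dev_buf, sizeof(host_buf));
  cudaMalloc(&dev_out, sizeof(uint32_t));
  cudaMemcpy(dev_buf, host_buf, sizeof(host_buf), cudaMemcpyHostToDevice);
  cudaMemset(dev_out, 0, sizeof(uint32_t));  // zero the counter, as GetRuntimeNumSegments does
  count_positive<<<1, 128>>>(dev_buf, 8, dev_out);
  cudaMemcpy(&host_out, dev_out, sizeof(uint32_t), cudaMemcpyDeviceToHost);
  printf("segments: %u\n", host_out);  // prints: segments: 3
  cudaFree(dev_buf);
  cudaFree(dev_out);
  return 0;
}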
@@ -150,6 +150,38 @@ DType get_ragged_offset_dtype(NVTE_QKV_Layout_Group layout_group, int64_t num_at
 size_t get_max_batch_size(size_t batch_size);
 size_t get_max_tokens(size_t num_tokens);
+
+class FusedAttnOffsetManager {
+ public:
+  static FusedAttnOffsetManager &Instance() {
+    static thread_local FusedAttnOffsetManager instance;
+    return instance;
+  }
+
+  size_t GetAndUpdateOffset(size_t increment) {
+    size_t ret = offset_;
+    offset_ += increment;
+    return ret;
+  }
+
+  FusedAttnOffsetManager(FusedAttnOffsetManager const &) = delete;
+  void operator=(FusedAttnOffsetManager const &) = delete;
+
+ private:
+  FusedAttnOffsetManager() {}
+  size_t offset_ = 0;
+};
+
+__global__ void populate_rng_state_kernel(int64_t *rng_state_dst, const int64_t *const seed,
+                                          int64_t offset);
+
+__global__ void get_runtime_num_segments_kernel(int32_t *cu_seqlen, size_t len, uint32_t *out);
+
+void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen,
+                           size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
+                           cudaStream_t stream);
+
+uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream);
+
 }  // namespace fused_attn
 }  // namespace transformer_engine
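
A note on the design: FusedAttnOffsetManager is a Meyers singleton kept in thread_local storage, so each host thread owns an independent, monotonically growing RNG offset and no locking is needed. GetAndUpdateOffset is fetch-then-add: it returns the current offset and only then advances it. A tiny usage sketch (illustration only, assuming the header above is in scope and the thread's instance is fresh):

auto &mgr = FusedAttnOffsetManager::Instance();
size_t a = mgr.GetAndUpdateOffset(16);  // a == 0; internal offset becomes 16
size_t b = mgr.GetAndUpdateOffset(16);  // b == 16; internal offset becomes 32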
@@ -580,6 +580,31 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
                          int64_t window_size_right, bool deterministic, NVTETensor workspace,
                          cudaStream_t stream);
+
+/*! \brief Update the RNG state with the seed and a freshly computed offset.
+ *
+ *  \param[in]  rng_state_dst  RNG state in which to store the seed and offset.
+ *  \param[in]  seed           Seed for the RNG state.
+ *  \param[in]  q_max_seqlen   Maximum sequence length used in computation for Q;
+ *                             it may be >= max(seqlen_q_i) for i = 0, ..., batch_size - 1.
+ *  \param[in]  kv_max_seqlen  Maximum sequence length used in computation for K and V;
+ *                             it may be >= max(seqlen_kv_i) for i = 0, ..., batch_size - 1.
+ *  \param[in]  backend        Fused attention backend.
+ *  \param[in]  stream         CUDA stream used for this operation.
+ */
+void nvte_populate_rng_state_async(NVTETensor rng_state_dst, const NVTETensor seed,
+                                   size_t q_max_seqlen, size_t kv_max_seqlen,
+                                   NVTE_Fused_Attn_Backend backend, cudaStream_t stream);
+
+/*! \brief Get the number of segments observed at runtime in a cumulative-seqlen buffer.
+ *
+ *  \param[in]  cu_seqlen  Cumulative sequence lengths, [batch_size + 1].
+ *  \param[in]  workspace  Workspace tensor.
+ *  \param[in]  len        batch_size x sequence_length.
+ *  \param[in]  stream     CUDA stream used for this operation.
+ */
+uint32_t nvte_get_runtime_num_segments(NVTETensor cu_seqlen, NVTETensor workspace, size_t len,
+                                       cudaStream_t stream);
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif
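
A hedged sketch of a call site for the two new entry points (not from the commit; the wrapper function and its parameter names are illustrative). Note that although the parameters are typed NVTETensor, these two functions are handed raw device pointers, exactly as the JAX hunk further below passes them:

#include <cstdint>
#include <cuda_runtime_api.h>
#include <transformer_engine/fused_attn.h>

static uint32_t count_segments_and_seed_rng(int32_t *q_cu_seqlens, uint32_t *counter_ws,
                                            int64_t *rng_state, int64_t *seed,
                                            size_t input_batch, size_t q_max_seqlen,
                                            size_t kv_max_seqlen,
                                            NVTE_Fused_Attn_Backend backend, cudaStream_t stream) {
  // counter_ws is a 4-byte device scratch buffer; the call synchronizes the
  // stream internally and returns the segment count on the host.
  uint32_t num_segments = nvte_get_runtime_num_segments(
      q_cu_seqlens, counter_ws, input_batch * q_max_seqlen, stream);
  // rng_state points at two device int64s; the kernel stores {seed[0], offset}.
  nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
  return num_segments;
}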
@@ -28,8 +28,8 @@
 #include "common/util/logging.h"
 #include "extensions/ffi.h"
 #include "extensions/misc.h"
-#include "extensions/utils.h"
 #include "transformer_engine/activation.h"
+#include "utils.h"

 // ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace
 XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode);
@@ -187,31 +187,31 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
   return pybind11::make_tuple(workspace_shape, query_workspace_tensor.dtype());
 }

 #define FUSED_ATTN_IMPL_COMMON_BLOCK \
   auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; \
   auto bias_shape = std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; \
   size_t num_segments = input_batch; \
   if (is_ragged) { \
     auto cudnn_runtime_version = cudnnGetVersion(); \
     if (cudnn_runtime_version >= 90300) { \
       num_segments = input_batch * max_segments_per_seq; \
     } else { \
-      size_t runtime_num_segments_q = \
-          GetRuntimeNumSegments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream); \
-      size_t runtime_num_segments_kv = \
-          GetRuntimeNumSegments(kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream); \
+      size_t runtime_num_segments_q = nvte_get_runtime_num_segments( \
+          q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream); \
+      size_t runtime_num_segments_kv = nvte_get_runtime_num_segments( \
+          kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream); \
       NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv); \
       NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq); \
       num_segments = runtime_num_segments_q; \
     } \
   } \
   std::vector<size_t> seq_shape{num_segments + 1}; \
   auto q_cu_seqlens_tensor = TensorWrapper(q_cu_seqlens, seq_shape, DType::kInt32); \
   auto kv_cu_seqlens_tensor = TensorWrapper(kv_cu_seqlens, seq_shape, DType::kInt32); \
   auto q_seq_offsets_tensor = TensorWrapper(q_seq_offsets, seq_shape, DType::kInt32); \
   auto k_seq_offsets_tensor = TensorWrapper(k_seq_offsets, seq_shape, DType::kInt32); \
   auto workspace_tensor = \
       TensorWrapper(workspace, std::vector<size_t>{wkspace_size}, wkspace_dtype); \
   auto layout_group = nvte_get_qkv_layout_group(qkv_layout);

 static void FusedAttnForwardImpl(
@@ -248,7 +248,7 @@ static void FusedAttnForwardImpl(
       static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
       mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
       head_dim, head_dim, window_size_left, window_size_right);
-  PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
+  nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);

   /* Auxiliary tensors (to be propagated to the backward pass later) */
   NVTETensorPack aux_output_tensors;
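
Worth noting in the macro above: with cuDNN >= 9.3 the ragged shapes are simply sized for the worst case, num_segments = input_batch * max_segments_per_seq (e.g. 2 * 4 = 8 segments, so seq_shape = {9}); on older cuDNN the exact count is required, which is why nvte_get_runtime_num_segments runs a device-side count and synchronizes the stream before the tensor shapes can be fixed.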
+/*************************************************************************
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+#include "utils.h"
+
+#include <cuda_runtime_api.h>
+
+#include <cassert>
+
+#include "common/util/cuda_runtime.h"
+
+namespace transformer_engine {
+namespace jax {
+
+int GetCudaRuntimeVersion() {
+  int ver = 0;
+  NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&ver));
+  return ver;
+}
+
+size_t GetCudnnRuntimeVersion() { return cudnnGetVersion(); }
+
+int GetDeviceComputeCapability(int gpu_id) { return transformer_engine::cuda::sm_arch(gpu_id); }
+
+}  // namespace jax
+}  // namespace transformer_engine
@@ -4,9 +4,6 @@
  * See LICENSE for license information.
  ************************************************************************/

-#ifndef TRANSFORMER_ENGINE_JAX_CSRC_UTILS_H_
-#define TRANSFORMER_ENGINE_JAX_CSRC_UTILS_H_
-
 #include <pybind11/pybind11.h>
 #include <transformer_engine/fused_attn.h>
@@ -25,12 +22,6 @@ int GetCudaRuntimeVersion();
 size_t GetCudnnRuntimeVersion();
 int GetDeviceComputeCapability(int gpu_id);

-void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen,
-                           size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
-                           cudaStream_t stream);
-
-uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream);
-
 class cudaDevicePropertiesManager {
  public:
   static cudaDevicePropertiesManager &Instance() {
@@ -63,28 +54,5 @@ class cudaDevicePropertiesManager {
   cudaDeviceProp prop_;
 };

-class FusedAttnOffsetManager {
- public:
-  static FusedAttnOffsetManager &Instance() {
-    static thread_local FusedAttnOffsetManager instance;
-    return instance;
-  }
-
-  size_t GetAndUpdateOffset(size_t increment) {
-    size_t ret = offset_;
-    offset_ += increment;
-    return ret;
-  }
-
-  FusedAttnOffsetManager(FusedAttnOffsetManager const &) = delete;
-  void operator=(FusedAttnOffsetManager const &) = delete;
-
- private:
-  FusedAttnOffsetManager() {}
-  size_t offset_ = 0;
-};
-
 }  // namespace jax
 }  // namespace transformer_engine
-
-#endif  // TRANSFORMER_ENGINE_JAX_CSRC_UTILS_H_
-/*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- *
- * See LICENSE for license information.
- ************************************************************************/
-
-#include <cuda_runtime_api.h>
-
-#include <cassert>
-
-#include "common/util/cuda_runtime.h"
-#include "utils.h"
-
-namespace transformer_engine {
-namespace jax {
-
-int GetCudaRuntimeVersion() {
-  int ver = 0;
-  NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&ver));
-  return ver;
-}
-
-size_t GetCudnnRuntimeVersion() { return cudnnGetVersion(); }
-
-int GetDeviceComputeCapability(int gpu_id) { return transformer_engine::cuda::sm_arch(gpu_id); }
-
-__global__ void populate_rng_state_kernel(int64_t *rng_state_dst, const int64_t *const seed,
-                                          int64_t offset) {
-  int tid = blockIdx.x * blockDim.x + threadIdx.x;
-  if (tid > 0) return;
-  rng_state_dst[0] = seed[0];
-  rng_state_dst[1] = offset;
-}
-
-void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen,
-                           size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
-                           cudaStream_t stream) {
-  size_t increment = 0;
-  if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
-    increment = 16;
-  } else {
-    constexpr int threads_per_cta = 128;
-    increment = (q_max_seqlen * kv_max_seqlen + threads_per_cta - 1) / threads_per_cta;
-  }
-  auto offset = FusedAttnOffsetManager::Instance().GetAndUpdateOffset(increment);
-  populate_rng_state_kernel<<<1, 1, 0, stream>>>(reinterpret_cast<int64_t *>(rng_state_dst),
-                                                 reinterpret_cast<const int64_t *>(seed), offset);
-  NVTE_CHECK_CUDA(cudaGetLastError());
-}
-
-__global__ void get_runtime_num_segments_kernel(int32_t *cu_seqlen, size_t len, uint32_t *out) {
-  int tid = blockDim.x * blockIdx.x + threadIdx.x;
-  if (tid >= len) return;
-  if (cu_seqlen[tid] > 0) {
-    // atomicAdd only support 32 bits dtype
-    atomicAdd(out, 1);
-  }
-}
-
-uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream) {
-  // workspace size requires 4 bytes
-  uint32_t *dout = static_cast<uint32_t *>(workspace);
-  uint32_t hout{};
-  cudaMemsetAsync(dout, 0, sizeof(uint32_t), stream);
-  constexpr int threads = 128;
-  const int blocks = (len - 1) / threads + 1;
-  get_runtime_num_segments_kernel<<<blocks, threads, 0, stream>>>(static_cast<int32_t *>(cu_seqlen),
-                                                                  len, dout);
-  cudaMemcpyAsync(&hout, dout, sizeof(uint32_t), cudaMemcpyDeviceToHost, stream);
-  cudaStreamSynchronize(stream);
-  return hout;
-}
-
-}  // namespace jax
-}  // namespace transformer_engine