Unverified Commit 51cd4415 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

[C][PyTorch]Make pytorch extensions pure cpp (#1754)



* First pass refactor
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* first pass
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* core compiles
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Include cuda dirs
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Compiles
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix test
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Move grad outside autocast
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix kv cache
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Address review comments
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Change src file name in cmake
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* move the kernels too
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Move comment
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Move comments around
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* more movement
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* move
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent dfe1a65a
......@@ -8,11 +8,7 @@ from pathlib import Path
import setuptools
from .utils import (
all_files_in_dir,
cuda_archs,
cuda_version,
)
from .utils import all_files_in_dir, cuda_version, get_cuda_include_dirs
def setup_pytorch_extension(
......@@ -30,55 +26,30 @@ def setup_pytorch_extension(
] + all_files_in_dir(extensions_dir)
# Header files
include_dirs = [
include_dirs = get_cuda_include_dirs()
include_dirs.extend(
[
common_header_files,
common_header_files / "common",
common_header_files / "common" / "include",
csrc_header_files,
]
)
# Compiler flags
cxx_flags = [
"-O3",
"-fvisibility=hidden",
]
nvcc_flags = [
"-O3",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_BFLOAT16_OPERATORS__",
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
"-U__CUDA_NO_BFLOAT162_OPERATORS__",
"-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--use_fast_math",
]
cuda_architectures = cuda_archs()
if "70" in cuda_architectures:
nvcc_flags.extend(["-gencode", "arch=compute_70,code=sm_70"])
# Version-dependent CUDA options
try:
version = cuda_version()
except FileNotFoundError:
print("Could not determine CUDA Toolkit version")
print("Could not determine CUDA version")
else:
if version < (12, 0):
raise RuntimeError("Transformer Engine requires CUDA 12.0 or newer")
nvcc_flags.extend(
(
"--threads",
os.getenv("NVTE_BUILD_THREADS_PER_JOB", "1"),
)
)
for arch in cuda_architectures.split(";"):
if arch == "70":
continue # Already handled
nvcc_flags.extend(["-gencode", f"arch=compute_{arch},code=sm_{arch}"])
if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
assert (
......@@ -87,7 +58,6 @@ def setup_pytorch_extension(
mpi_path = Path(os.getenv("MPI_HOME"))
include_dirs.append(mpi_path / "include")
cxx_flags.append("-DNVTE_UB_WITH_MPI")
nvcc_flags.append("-DNVTE_UB_WITH_MPI")
library_dirs = []
libraries = []
......@@ -100,21 +70,17 @@ def setup_pytorch_extension(
library_dirs.append(nvshmem_home / "lib")
libraries.append("nvshmem_host")
cxx_flags.append("-DNVTE_ENABLE_NVSHMEM")
nvcc_flags.append("-DNVTE_ENABLE_NVSHMEM")
# Construct PyTorch CUDA extension
sources = [str(path) for path in sources]
include_dirs = [str(path) for path in include_dirs]
from torch.utils.cpp_extension import CUDAExtension
from torch.utils.cpp_extension import CppExtension
return CUDAExtension(
return CppExtension(
name="transformer_engine_torch",
sources=[str(src) for src in sources],
include_dirs=[str(inc) for inc in include_dirs],
extra_compile_args={
"cxx": cxx_flags,
"nvcc": nvcc_flags,
},
extra_compile_args={"cxx": cxx_flags},
libraries=[str(lib) for lib in libraries],
library_dirs=[str(lib_dir) for lib_dir in library_dirs],
)
......@@ -130,18 +130,20 @@ def dtype_tols(dtype: torch.dtype) -> Dict[str, float]:
def assert_allclose(
l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float, rtol: float = None
l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float = None, rtol: float = None
) -> bool:
"""Ensures two lists are equal."""
assert len(l1) == len(l2), "Unequal number of outputs."
for i, (t1, t2) in enumerate(zip(l1, l2)):
tols = dict(atol=atol)
tols = dtype_tols(t2.dtype)
if rtol is not None:
tols["rtol"] = rtol
if atol is not None:
tols["atol"] = atol
result = torch.allclose(t1, t2, **tols)
if not result:
diff = torch.abs(t1 - t2)
tol = atol + (rtol * torch.abs(t2))
tol = tols["atol"] + (tols["rtol"] * torch.abs(t2))
exceed_mask = diff > tol
if exceed_mask.any():
indices = torch.nonzero(exceed_mask, as_tuple=True)
......
......@@ -66,6 +66,9 @@ list(APPEND transformer_engine_SOURCES
transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise.cu
activation/gelu.cu
fused_attn/flash_attn.cu
fused_attn/context_parallel.cu
fused_attn/kv_cache.cu
fused_attn/fused_attn_f16_max512_seqlen.cu
fused_attn/fused_attn_f16_arbitrary_seqlen.cu
activation/relu.cu
......@@ -173,6 +176,9 @@ set_source_files_properties(fused_softmax/scaled_masked_softmax.cu
multi_tensor/l2norm.cu
multi_tensor/scale.cu
multi_tensor/sgd.cu
fused_attn/flash_attn.cu
fused_attn/context_parallel.cu
fused_attn/kv_cache.cu
PROPERTIES
COMPILE_OPTIONS "--use_fast_math")
option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF)
......
......@@ -111,7 +111,7 @@ struct Tensor {
columnwise_scale_inv(nullptr, {1}, DType::kFloat32),
scaling_mode(NVTE_DELAYED_TENSOR_SCALING) {}
int numel() const {
size_t numel() const {
size_t acc = 1;
for (const auto dim : shape()) {
acc *= dim;
......@@ -133,6 +133,14 @@ struct Tensor {
return data.dtype;
}
size_t dim() const {
if (!has_data() && has_columnwise_data()) {
return columnwise_data.shape.size();
} else {
return data.shape.size();
}
}
std::vector<size_t> shape() const {
/* Note: We sometimes experience spurious compiler errors
* (-Wstringop-overflow) from this function. It appears that GCC
......@@ -385,6 +393,33 @@ struct TypeInfo {
NVTE_ERROR("Invalid type."); \
}
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(dtype, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
case DType::kFloat32: { \
using type = float; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat16: { \
using type = fp16; \
{ __VA_ARGS__ } \
} break; \
case DType::kBFloat16: { \
using type = bf16; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat8E4M3: { \
using type = fp8e4m3; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat8E5M2: { \
using type = fp8e5m2; \
{ __VA_ARGS__ } \
} break; \
default: \
NVTE_ERROR("Invalid type."); \
}
#define TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(dtype, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "../common.h"
#include "transformer_engine/fused_attn.h"
namespace transformer_engine {
namespace flash_attention {
constexpr int warp_size = 32;
constexpr int type_size = 2; // FP16 or BF16
constexpr int nvec = sizeof(uint64_t) / type_size;
constexpr int load_size = warp_size * nvec;
constexpr int block_size = 512;
template <typename T>
__launch_bounds__(block_size) __global__
void prepare_kernel_fwd(const T *qkvi, T *qkv, const size_t B, const size_t S, const size_t Z,
const size_t W) {
const int warpid = (blockDim.x * blockIdx.x + threadIdx.x) / warp_size;
const int id_in_warp = threadIdx.x % warp_size;
const size_t offset_input = blockIdx.y * W + warpid * 3 * W * Z + id_in_warp * nvec;
const T *my_input = qkvi + offset_input;
const size_t s = warpid / B;
if (s >= S) return;
const size_t b = warpid % B;
const size_t offset_output = blockIdx.y * B * S * Z * W + (s + b * S) * W * Z + id_in_warp * nvec;
T *my_output = qkv + offset_output;
for (int i = 0; i < Z; ++i) {
uint64_t *out = reinterpret_cast<uint64_t *>(my_output + i * load_size);
*out = *reinterpret_cast<const uint64_t *>(my_input + i * load_size * 3);
}
}
template <typename T>
__launch_bounds__(block_size) __global__
void prepare_kernel_bwd(const T *q, const T *k, const T *v, T *qkv, const size_t B,
const size_t S, const size_t Z, const size_t W) {
const T *input = blockIdx.y == 0 ? q : (blockIdx.y == 1 ? k : v);
const int warpid = (blockDim.x * blockIdx.x + threadIdx.x) / warp_size;
const int id_in_warp = threadIdx.x % warp_size;
const size_t offset_input = warpid * W * Z + id_in_warp * nvec;
const T *my_input = input + offset_input;
const size_t b = warpid / S;
if (b >= B) return;
const size_t s = warpid % S;
const size_t offset_output = (b + s * B) * 3 * W * Z + id_in_warp * nvec + blockIdx.y * W;
T *my_output = qkv + offset_output;
for (int i = 0; i < Z; ++i) {
uint64_t *out = reinterpret_cast<uint64_t *>(my_output + i * load_size * 3);
*out = *reinterpret_cast<const uint64_t *>(my_input + i * load_size);
}
}
void prepare_flash_attn_fwd(Tensor qkvi, Tensor qkv, cudaStream_t stream) {
using namespace transformer_engine;
NVTE_CHECK(qkvi.dim() == 4, "Expected 4-dim tensor.");
NVTE_CHECK(qkvi.dtype() == DType::kFloat16 || qkvi.dtype() == DType::kBFloat16);
auto qkvi_shape = qkvi.shape();
NVTE_CHECK(qkvi_shape[3] % load_size == 0);
NVTE_CHECK(qkvi_shape[3] == load_size);
// [s, b, n, h * 3] -> [3, b, s, n, h]
std::vector<uint64_t> shape = {3, qkvi_shape[1], qkvi_shape[0], qkvi_shape[2], qkvi_shape[3]};
size_t warps = qkvi_shape[0] * qkvi_shape[1];
size_t warps_per_block = block_size / warp_size;
size_t blocks = (warps + warps_per_block - 1) / warps_per_block;
dim3 grid(blocks, 3);
int threads = block_size;
TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(
qkvi.dtype(), dtype,
prepare_kernel_fwd<dtype><<<grid, threads, 0, stream>>>(
reinterpret_cast<dtype *>(qkvi.data.dptr), reinterpret_cast<dtype *>(qkv.data.dptr),
shape[1], shape[2], shape[3], shape[4]););
}
void prepare_flash_attn_bwd(Tensor q, Tensor k, Tensor v, Tensor qkv, cudaStream_t stream) {
using namespace transformer_engine;
NVTE_CHECK(q.dim() == 4, "Expected 4-dim tensor.");
NVTE_CHECK(k.dim() == 4, "Expected 4-dim tensor.");
NVTE_CHECK(v.dim() == 4, "Expected 4-dim tensor.");
NVTE_CHECK(q.dtype() == DType::kFloat16 || q.dtype() == DType::kBFloat16);
NVTE_CHECK(k.dtype() == q.dtype());
NVTE_CHECK(v.dtype() == q.dtype());
auto q_shape = q.shape();
auto k_shape = k.shape();
auto v_shape = v.shape();
NVTE_CHECK(q_shape[3] % load_size == 0);
NVTE_CHECK(q_shape[3] == load_size);
NVTE_CHECK(k_shape[3] % load_size == 0);
NVTE_CHECK(k_shape[3] == load_size);
NVTE_CHECK(v_shape[3] % load_size == 0);
NVTE_CHECK(v_shape[3] == load_size);
// 3 x [s, b, n, h] -> [b, s, n, 3 * h]
std::vector<uint64_t> shape = {q_shape[1], q_shape[0], q_shape[2], 3 * q_shape[3]};
size_t warps = q_shape[0] * q_shape[1];
size_t warps_per_block = block_size / warp_size;
size_t blocks = (warps + warps_per_block - 1) / warps_per_block;
dim3 grid(blocks, 3);
int threads = block_size;
TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(
q.dtype(), dtype,
prepare_kernel_bwd<dtype><<<grid, threads, 0, stream>>>(
reinterpret_cast<dtype *>(q.data.dptr), reinterpret_cast<dtype *>(k.data.dptr),
reinterpret_cast<dtype *>(v.data.dptr), reinterpret_cast<dtype *>(qkv.data.dptr),
q_shape[0], q_shape[1], q_shape[2], q_shape[3]););
}
} // namespace flash_attention
} // namespace transformer_engine
void nvte_prepare_flash_attn_fwd(NVTETensor qkvi, NVTETensor qkv, cudaStream_t stream) {
NVTE_API_CALL(nvte_prepare_flash_attn_fwd);
using namespace transformer_engine;
flash_attention::prepare_flash_attn_fwd(*reinterpret_cast<Tensor *>(qkvi),
*reinterpret_cast<Tensor *>(qkv), stream);
}
void nvte_prepare_flash_attn_bwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTETensor qkv,
cudaStream_t stream) {
NVTE_API_CALL(nvte_prepare_flash_attn_bwd);
using namespace transformer_engine;
flash_attention::prepare_flash_attn_bwd(
*reinterpret_cast<Tensor *>(q), *reinterpret_cast<Tensor *>(k),
*reinterpret_cast<Tensor *>(v), *reinterpret_cast<Tensor *>(qkv), stream);
}
......@@ -3,48 +3,15 @@
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_KV_CACHE_CUH_
#define TRANSFORMER_ENGINE_FUSED_ATTN_KV_CACHE_CUH_
namespace transformer_engine {
namespace fused_attn {
template <typename scalar_t>
__global__ void convert_thd_to_bshd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens,
int b, int max_seq_len, int h, int d) {
// tensor: thd; new_tensor: bshd
// cu_seqlens: [b + 1]
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
int num_elts = (cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]) * h * d;
int thd_offset = cu_seqlens[batch_idx] * h * d;
int bshd_offset = batch_idx * max_seq_len * h * d;
scalar_t *thd_token = tensor + thd_offset;
scalar_t *bshd_token = new_tensor + bshd_offset;
for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
*(bshd_token + i) = *(thd_token + i);
}
}
}
#include "../common.h"
#include "transformer_engine/fused_attn.h"
template <typename scalar_t>
__global__ void convert_bshd_to_thd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens,
int b, int max_seq_len, int h, int d) {
// tensor: bshd; new_tensor: thd
// cu_seqlens: [b + 1]
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
int seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx];
int num_elts = seqlen * h * d;
int bshd_offset = batch_idx * max_seq_len * h * d;
int thd_offset = cu_seqlens[batch_idx] * h * d;
scalar_t *bshd_token = tensor + bshd_offset;
scalar_t *thd_token = new_tensor + thd_offset;
for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
*(thd_token + i) = *(bshd_token + i);
}
}
}
namespace transformer_engine {
namespace kv_cache {
template <typename scalar_t>
__global__ void reindex_kv_cache_kernel(scalar_t *k_cache, scalar_t *v_cache, int *batch_indices,
template <typename dtype>
__global__ void reindex_kv_cache_kernel(dtype *k_cache, dtype *v_cache, int *batch_indices,
int *cu_new_lens, int *cu_cached_lens, int h_kv, int d_k,
int d_v, int b, int max_seq_len) {
// k_cache, v_cache: bshd
......@@ -75,11 +42,11 @@ __global__ void reindex_kv_cache_kernel(scalar_t *k_cache, scalar_t *v_cache, in
}
}
template <typename scalar_t>
__global__ void copy_to_kv_cache_kernel(scalar_t *new_k, scalar_t *new_v, scalar_t *k_cache,
scalar_t *v_cache, int *page_table, int *cu_new_lens,
int *cu_cached_lens, NVTE_QKV_Format qkv_format, int h_kv,
int d_k, int d_v, int b, int max_ctx_len, int max_seq_len,
template <typename dtype>
__global__ void copy_to_kv_cache_kernel(dtype *new_k, dtype *new_v, dtype *k_cache, dtype *v_cache,
int *page_table, int *cu_new_lens, int *cu_cached_lens,
NVTE_QKV_Format qkv_format, int h_kv, int d_k, int d_v,
int b, int max_ctx_len, int max_seq_len,
int max_pages_per_seq, bool is_non_paged) {
// new_k, new_v: qkv_format; k_cache, v_cache: bshd
// cu_new_lens, cu_cached_lens: [b + 1]
......@@ -140,6 +107,191 @@ __global__ void copy_to_kv_cache_kernel(scalar_t *new_k, scalar_t *new_v, scalar
}
}
}
} // namespace fused_attn
template <typename dtype>
void copy_to_kv_cache_launcher(Tensor new_k, Tensor new_v, Tensor k_cache, Tensor v_cache,
Tensor page_table, Tensor cu_new_lens, Tensor cu_cached_lens,
NVTE_QKV_Format qkv_format, int h_kv, int d_k, int d_v, int b,
int max_ctx_len, int max_seq_len, int max_pages_per_seq,
bool is_non_paged, cudaStream_t stream) {
if (new_k.has_data() && new_v.has_data() && k_cache.has_data() && v_cache.has_data()) {
if (is_non_paged) {
reindex_kv_cache_kernel<<<16, 256, 0, stream>>>(
reinterpret_cast<dtype *>(k_cache.data.dptr),
reinterpret_cast<dtype *>(v_cache.data.dptr),
reinterpret_cast<int *>(page_table.data.dptr),
reinterpret_cast<int *>(cu_new_lens.data.dptr),
reinterpret_cast<int *>(cu_cached_lens.data.dptr), h_kv, d_k, d_v, b, max_seq_len);
}
copy_to_kv_cache_kernel<<<16, 256, 0, stream>>>(
reinterpret_cast<dtype *>(new_k.data.dptr), reinterpret_cast<dtype *>(new_v.data.dptr),
reinterpret_cast<dtype *>(k_cache.data.dptr), reinterpret_cast<dtype *>(v_cache.data.dptr),
reinterpret_cast<int *>(page_table.data.dptr),
reinterpret_cast<int *>(cu_new_lens.data.dptr),
reinterpret_cast<int *>(cu_cached_lens.data.dptr), qkv_format, h_kv, d_k, d_v, b,
max_ctx_len, max_seq_len, max_pages_per_seq, is_non_paged);
}
}
void copy_to_kv_cache(Tensor new_k, Tensor new_v, Tensor k_cache, Tensor v_cache, Tensor page_table,
Tensor cu_new_lens, Tensor cu_cached_lens, NVTE_QKV_Format qkv_format, int b,
int max_ctx_len, int max_seq_len, int max_pages_per_seq, bool is_non_paged,
cudaStream_t stream) {
int h_kv = new_k.shape()[new_k.dim() - 2];
int d_k = new_k.shape()[new_k.dim() - 1];
int d_v = new_v.shape()[new_v.dim() - 1];
NVTE_CHECK(k_cache.dtype() == v_cache.dtype() && new_k.dtype() == new_v.dtype() &&
new_k.dtype() == k_cache.dtype(),
"new_k, new_v, k_cache and v_cache must be of the same data type.");
NVTE_CHECK(qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD ||
qkv_format == NVTE_QKV_Format::NVTE_THD,
"qkv_format must be {BSHD, SBHD, THD}.");
TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(
k_cache.dtype(), dtype,
copy_to_kv_cache_launcher<dtype>(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens,
cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len,
max_seq_len, max_pages_per_seq, is_non_paged, stream););
}
template <typename scalar_t>
__global__ void convert_thd_to_bshd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens,
int b, int max_seq_len, int h, int d) {
// tensor: thd; new_tensor: bshd
// cu_seqlens: [b + 1]
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
int num_elts = (cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]) * h * d;
int thd_offset = cu_seqlens[batch_idx] * h * d;
int bshd_offset = batch_idx * max_seq_len * h * d;
scalar_t *thd_token = tensor + thd_offset;
scalar_t *bshd_token = new_tensor + bshd_offset;
for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
*(bshd_token + i) = *(thd_token + i);
}
}
}
template <typename scalar_t>
void convert_thd_to_bshd_launcher(Tensor tensor, Tensor new_tensor, Tensor cu_seqlens, int b,
int max_seq_len, int h, int d, cudaStream_t stream) {
using namespace transformer_engine;
convert_thd_to_bshd_kernel<<<16, 256, 0, stream>>>(
reinterpret_cast<scalar_t *>(tensor.data.dptr),
reinterpret_cast<scalar_t *>(new_tensor.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), b, max_seq_len, h, d);
}
void convert_thd_to_bshd(Tensor tensor, Tensor cu_seqlens, Tensor new_tensor, int b,
int max_seq_len, cudaStream_t stream) {
using namespace transformer_engine;
auto tensor_shape = tensor.shape();
TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(
new_tensor.dtype(), dtype,
convert_thd_to_bshd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len,
tensor_shape[1], tensor_shape[2], stream););
}
template <typename scalar_t>
__global__ void convert_bshd_to_thd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens,
int b, int max_seq_len, int h, int d) {
// tensor: bshd; new_tensor: thd
// cu_seqlens: [b + 1]
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
int seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx];
int num_elts = seqlen * h * d;
int bshd_offset = batch_idx * max_seq_len * h * d;
int thd_offset = cu_seqlens[batch_idx] * h * d;
scalar_t *bshd_token = tensor + bshd_offset;
scalar_t *thd_token = new_tensor + thd_offset;
for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
*(thd_token + i) = *(bshd_token + i);
}
}
}
template <typename scalar_t>
void convert_bshd_to_thd_launcher(Tensor tensor, Tensor new_tensor, Tensor cu_seqlens, int b,
int max_seq_len, int h, int d, cudaStream_t stream) {
using namespace transformer_engine;
convert_bshd_to_thd_kernel<<<16, 256, 0, stream>>>(
reinterpret_cast<scalar_t *>(tensor.data.dptr),
reinterpret_cast<scalar_t *>(new_tensor.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), b, max_seq_len, h, d);
}
void convert_bshd_to_thd(Tensor tensor, Tensor cu_seqlens, Tensor new_tensor, int t,
cudaStream_t stream) {
using namespace transformer_engine;
auto tensor_shape = tensor.shape();
TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(
tensor.dtype(), dtype,
convert_bshd_to_thd_launcher<dtype>(tensor, new_tensor, cu_seqlens, tensor_shape[0],
tensor_shape[1], tensor_shape[2], tensor_shape[3],
stream););
}
} // namespace kv_cache
} // namespace transformer_engine
#endif
/***************************************************************************************************
* KV Cache: Copy new KV tokens to the KV cache
* 1. new_k and new_v are in qkv_format; k_cache and v_cache are in 'bshd' format
* 2. cu_new_lens and cu_cached_lens are in shape [b + 1]; cu_cached_lens include the added lens
* in current step
* 3. Non-paged KV cache is a special case of paged KV cache, with page_table = [b, 1] and
* max_pages_per_seq = 1. We use the same underlying kernel for both non-paged and paged.
* Set is_non_paged = True/False to indicate as such.
* 4. is_non_paged = True also re-indexes the KV cache, e.g. the initial batch indices [0, 3, 1, 2]
* becomes [0, 1, 1, 2]. The page_table = batch_indices.unsqueeze(1) is however unchanged.
* batch_indices_post can be used for monotonical indexing, i.e. [0, 1, 2, 3]. batch_indices is
* preserved for the next layer in the same iteration.
* 5. Only supports same page_table for k_cache and v_cache
* 6. Only pad_between_seqs = False when qkv_format = thd, i.e. there should be no pad tokens
* between sequences in new_k and new_v such as [a a a 0..0 b b 0..0 c 0..0].
**************************************************************************************************/
void nvte_copy_to_kv_cache(NVTETensor new_k, NVTETensor new_v, NVTETensor k_cache,
NVTETensor v_cache, NVTETensor page_table, NVTETensor cu_new_lens,
NVTETensor cu_cached_lens, NVTE_QKV_Format qkv_format, int b,
int max_ctx_len, int max_seq_len, int max_pages_per_seq,
int is_non_paged, cudaStream_t stream) {
NVTE_API_CALL(nvte_copy_to_kv_cache);
using namespace transformer_engine;
kv_cache::copy_to_kv_cache(
*reinterpret_cast<Tensor *>(new_k), *reinterpret_cast<Tensor *>(new_v),
*reinterpret_cast<Tensor *>(k_cache), *reinterpret_cast<Tensor *>(v_cache),
*reinterpret_cast<Tensor *>(page_table), *reinterpret_cast<Tensor *>(cu_new_lens),
*reinterpret_cast<Tensor *>(cu_cached_lens), qkv_format, b, max_ctx_len, max_seq_len,
max_pages_per_seq, is_non_paged, stream);
}
/***************************************************************************************************
* KV Cache: Convert a tensor from qkv_format = thd to qkv_format = bshd
**************************************************************************************************/
void nvte_convert_thd_to_bshd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor,
int b, int max_seq_len, cudaStream_t stream) {
NVTE_API_CALL(nvte_convert_thd_to_bshd);
using namespace transformer_engine;
kv_cache::convert_thd_to_bshd(*reinterpret_cast<Tensor *>(tensor),
*reinterpret_cast<Tensor *>(cu_seqlens),
*reinterpret_cast<Tensor *>(new_tensor), b, max_seq_len, stream);
}
/***************************************************************************************************
* KV Cache: Convert a tensor from qkv_format = bshd to qkv_format = thd
**************************************************************************************************/
void nvte_convert_bshd_to_thd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor,
int t, cudaStream_t stream) {
NVTE_API_CALL(nvte_convert_bshd_to_thd);
using namespace transformer_engine;
kv_cache::convert_bshd_to_thd(*reinterpret_cast<Tensor *>(tensor),
*reinterpret_cast<Tensor *>(cu_seqlens),
*reinterpret_cast<Tensor *>(new_tensor), t, stream);
}
......@@ -610,5 +610,27 @@ uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cud
return hout;
}
__global__ void extract_seed_and_offset(int64_t *rng_state_ptr, bool captured, int64_t *seed_ptr,
uint64_t seed_val, int64_t *offset_ptr, uint64_t offset_val,
uint32_t offset_intragraph) {
if (captured) {
rng_state_ptr[0] = *seed_ptr;
rng_state_ptr[1] = static_cast<int64_t>(*offset_ptr + static_cast<int64_t>(offset_intragraph));
} else {
rng_state_ptr[0] = static_cast<int64_t>(seed_val);
rng_state_ptr[1] = static_cast<int64_t>(offset_val);
}
}
} // namespace fused_attn
} // namespace transformer_engine
void nvte_extract_seed_and_offset(int64_t *rng_state_ptr, int captured, int64_t *seed_ptr,
uint64_t seed_val, int64_t *offset_ptr, uint64_t offset_val,
uint32_t offset_intragraph, cudaStream_t stream) {
NVTE_API_CALL(nvte_extract_seed_and_offset);
using namespace transformer_engine;
fused_attn::extract_seed_and_offset<<<1, 1, 0, stream>>>(
rng_state_ptr, captured, seed_ptr, seed_val, offset_ptr, offset_val, offset_intragraph);
}
......@@ -244,7 +244,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S,
NVTETensor O, NVTETensorPack* Aux_CTX_Tensors,
NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
const NVTETensor rng_state, size_t max_seqlen, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
......@@ -300,7 +300,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
*/
void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO,
const NVTETensor S, NVTETensor dP,
const NVTETensorPack* Aux_CTX_Tensors, NVTETensor dQKV,
const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV,
NVTETensor dBias, const NVTETensor cu_seqlens,
const NVTETensor cu_seqlens_padded, size_t max_seqlen,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
......@@ -368,7 +368,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
*/
void nvte_fused_attn_fwd_kvpacked(
const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O,
NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
......@@ -429,7 +429,7 @@ void nvte_fused_attn_fwd_kvpacked(
*/
void nvte_fused_attn_bwd_kvpacked(
const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO,
const NVTETensor S, NVTETensor dP, const NVTETensorPack* Aux_CTX_Tensors, NVTETensor dQ,
const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ,
NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout,
......@@ -500,7 +500,7 @@ void nvte_fused_attn_bwd_kvpacked(
*/
void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
const NVTETensor Bias, NVTETensor S, NVTETensor O,
NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
const NVTETensor page_table_v, const NVTETensor rng_state,
......@@ -569,7 +569,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
*/
void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP,
const NVTETensorPack* Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK,
const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK,
NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q,
......@@ -604,6 +604,51 @@ void nvte_populate_rng_state_async(NVTETensor rng_state_dst, const NVTETensor se
uint32_t nvte_get_runtime_num_segments(NVTETensor cu_seqlen, NVTETensor workspace, size_t len,
cudaStream_t stream);
void nvte_extract_seed_and_offset(int64_t *rng_state_ptr, int captured, int64_t *seed_ptr,
uint64_t seed_val, int64_t *offset_ptr, uint64_t offset_val,
uint32_t offset_intragraph, cudaStream_t stream);
void nvte_copy_to_kv_cache(NVTETensor new_k, NVTETensor new_v, NVTETensor k_cache,
NVTETensor v_cache, NVTETensor page_table, NVTETensor cu_new_lens,
NVTETensor cu_cached_lens, NVTE_QKV_Format qkv_format, int b,
int max_ctx_len, int max_seq_len, int max_pages_per_seq,
int is_non_paged, cudaStream_t stream);
void nvte_cp_thd_read_half_tensor(const NVTETensor &tensor, const NVTETensor &cu_seqlens,
NVTETensor half, int half_idx, cudaStream_t stream);
void nvte_cp_thd_second_half_lse_correction(NVTETensor lse, const NVTETensor &lse_per_step,
const NVTETensor &cu_seqlens, int lse_packed,
cudaStream_t stream);
void nvte_cp_thd_read_second_half_lse(const NVTETensor &lse, const NVTETensor &cu_seqlens,
NVTETensor half_lse, int lse_packed,
int second_half_lse_seqlen, cudaStream_t stream);
void nvte_cp_thd_out_correction(NVTETensor out, const NVTETensor &out_per_step,
const NVTETensor &lse, const NVTETensor &lse_per_step,
const NVTETensor &cu_seqlens, int only_second_half, int lse_packed,
cudaStream_t stream);
void nvte_cp_thd_grad_correction(NVTETensor grad, const NVTETensor &grad_per_step,
const NVTETensor &cu_seqlens, const char *first_half,
const char *second_half, cudaStream_t stream);
void nvte_cp_thd_get_partitioned_indices(const NVTETensor &cu_seqlens, NVTETensor output,
int total_tokens, int world_size, int rank,
cudaStream_t stream);
void nvte_convert_thd_to_bshd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor,
int b, int max_seq_len, cudaStream_t stream);
void nvte_convert_bshd_to_thd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor,
int t, cudaStream_t stream);
void nvte_prepare_flash_attn_fwd(NVTETensor qkvi, NVTETensor qkv, cudaStream_t stream);
void nvte_prepare_flash_attn_bwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTETensor qkv,
cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
......
......@@ -218,6 +218,26 @@ std::vector<size_t> getTensorShape(at::Tensor t);
transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
const std::string& fp8_recipe);
inline size_t typeToSize(transformer_engine::DType t) {
switch (t) {
case transformer_engine::DType::kInt64:
return 8;
case transformer_engine::DType::kInt32:
case transformer_engine::DType::kFloat32:
return 4;
case transformer_engine::DType::kInt16:
case transformer_engine::DType::kFloat16:
case transformer_engine::DType::kBFloat16:
return 2;
case transformer_engine::DType::kByte:
case transformer_engine::DType::kFloat8E4M3:
case transformer_engine::DType::kFloat8E5M2:
return 1;
default:
NVTE_ERROR("Invalid type");
}
}
inline at::ScalarType GetATenDType(transformer_engine::DType t) {
switch (t) {
case transformer_engine::DType::kInt16:
......
......@@ -72,10 +72,10 @@ at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v);
at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len);
at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t);
void copy_to_kv_cache(torch::Tensor new_k, torch::Tensor new_v, torch::Tensor k_cache,
torch::Tensor v_cache, torch::Tensor page_table, torch::Tensor cu_new_lens,
torch::Tensor cu_cached_lens, NVTE_QKV_Format kv_format, int b,
int max_ctx_len, int max_seq_len, int max_pages_per_seq, bool is_non_paged);
void copy_to_kv_cache(at::Tensor new_k, at::Tensor new_v, at::Tensor k_cache, at::Tensor v_cache,
at::Tensor page_table, at::Tensor cu_new_lens, at::Tensor cu_cached_lens,
NVTE_QKV_Format kv_format, int b, int max_ctx_len, int max_seq_len,
int max_pages_per_seq, bool is_non_paged);
/***************************************************************************************************
* GEMM
......@@ -392,12 +392,11 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
namespace nvshmem_api {
void init_nvshmem_backend(c10d::ProcessGroup *process_group);
torch::Tensor create_nvshmem_tensor(const std::vector<int64_t> &shape, c10::ScalarType dtype);
at::Tensor create_nvshmem_tensor(const std::vector<int64_t> &shape, c10::ScalarType dtype);
void nvshmem_send_on_current_stream(torch::Tensor src, torch::Tensor dst, int peer,
torch::Tensor signal);
void nvshmem_send_on_current_stream(at::Tensor src, at::Tensor dst, int peer, at::Tensor signal);
void nvshmem_wait_on_current_stream(torch::Tensor signal, const std::string &wait_kind);
void nvshmem_wait_on_current_stream(at::Tensor signal, const std::string &wait_kind);
void nvshmem_finalize();
} // namespace nvshmem_api
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment