Unverified Commit 51cd4415 authored by Kirthi Shankar Sivamani, committed by GitHub

[C][PyTorch]Make pytorch extensions pure cpp (#1754)



* First pass refactor
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* first pass
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* core compiles
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Include cuda dirs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Compiles
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix test
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Move grad outside autocast
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix kv cache
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Address review comments
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Change src file name in cmake
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* move the kernels too
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Move comment
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Move comments around
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* more movement
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* move
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent dfe1a65a
@@ -8,11 +8,7 @@ from pathlib import Path
 import setuptools

-from .utils import (
-    all_files_in_dir,
-    cuda_archs,
-    cuda_version,
-)
+from .utils import all_files_in_dir, cuda_version, get_cuda_include_dirs


 def setup_pytorch_extension(
@@ -30,55 +26,30 @@ def setup_pytorch_extension(
     ] + all_files_in_dir(extensions_dir)

     # Header files
-    include_dirs = [
-        common_header_files,
-        common_header_files / "common",
-        common_header_files / "common" / "include",
-        csrc_header_files,
-    ]
+    include_dirs = get_cuda_include_dirs()
+    include_dirs.extend(
+        [
+            common_header_files,
+            common_header_files / "common",
+            common_header_files / "common" / "include",
+            csrc_header_files,
+        ]
+    )

     # Compiler flags
     cxx_flags = [
         "-O3",
         "-fvisibility=hidden",
     ]
-    nvcc_flags = [
-        "-O3",
-        "-U__CUDA_NO_HALF_OPERATORS__",
-        "-U__CUDA_NO_HALF_CONVERSIONS__",
-        "-U__CUDA_NO_BFLOAT16_OPERATORS__",
-        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
-        "-U__CUDA_NO_BFLOAT162_OPERATORS__",
-        "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
-        "--expt-relaxed-constexpr",
-        "--expt-extended-lambda",
-        "--use_fast_math",
-    ]
-    cuda_architectures = cuda_archs()
-
-    if "70" in cuda_architectures:
-        nvcc_flags.extend(["-gencode", "arch=compute_70,code=sm_70"])

     # Version-dependent CUDA options
     try:
         version = cuda_version()
     except FileNotFoundError:
-        print("Could not determine CUDA Toolkit version")
+        print("Could not determine CUDA version")
     else:
         if version < (12, 0):
             raise RuntimeError("Transformer Engine requires CUDA 12.0 or newer")
-        nvcc_flags.extend(
-            (
-                "--threads",
-                os.getenv("NVTE_BUILD_THREADS_PER_JOB", "1"),
-            )
-        )
-        for arch in cuda_architectures.split(";"):
-            if arch == "70":
-                continue  # Already handled
-            nvcc_flags.extend(["-gencode", f"arch=compute_{arch},code=sm_{arch}"])

     if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
         assert (
@@ -87,7 +58,6 @@ def setup_pytorch_extension(
         mpi_path = Path(os.getenv("MPI_HOME"))
         include_dirs.append(mpi_path / "include")
         cxx_flags.append("-DNVTE_UB_WITH_MPI")
-        nvcc_flags.append("-DNVTE_UB_WITH_MPI")

     library_dirs = []
     libraries = []
@@ -100,21 +70,17 @@ def setup_pytorch_extension(
         library_dirs.append(nvshmem_home / "lib")
         libraries.append("nvshmem_host")
         cxx_flags.append("-DNVTE_ENABLE_NVSHMEM")
-        nvcc_flags.append("-DNVTE_ENABLE_NVSHMEM")

     # Construct PyTorch CUDA extension
     sources = [str(path) for path in sources]
     include_dirs = [str(path) for path in include_dirs]
-    from torch.utils.cpp_extension import CUDAExtension
+    from torch.utils.cpp_extension import CppExtension

-    return CUDAExtension(
+    return CppExtension(
         name="transformer_engine_torch",
         sources=[str(src) for src in sources],
         include_dirs=[str(inc) for inc in include_dirs],
-        extra_compile_args={
-            "cxx": cxx_flags,
-            "nvcc": nvcc_flags,
-        },
+        extra_compile_args={"cxx": cxx_flags},
         libraries=[str(lib) for lib in libraries],
         library_dirs=[str(lib_dir) for lib_dir in library_dirs],
     )
@@ -1628,8 +1628,8 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoP
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
        )
-        if is_training:
-            out.backward(out_grad)
+    if is_training:
+        out.backward(out_grad)

    param_names = []
    param_names.append("hidden_states.grad")
@@ -1879,8 +1879,8 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training):
            checkpoint_core_attention=False,
            core_attention_bias_type=config.attn_bias_type,
        )
-        if is_training:
-            out.backward(out_grad)
+    if is_training:
+        out.backward(out_grad)

    if is_training:
        return out, (inp[0].grad, inp[1].grad, inp[2].grad)
@@ -1993,7 +1993,7 @@ def _run_custom_mha_fp8(dtype, config, backend):
    mha = Custom_MHA_FP8(config).to(dtype=dtype, device="cuda")
    with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        out = mha(inp, cu_seqlens, config.max_seqlen_q)
-        out.backward(out_grad)
+    out.backward(out_grad)

    out = torch.load("out.pt")
    dqkv = torch.load("dqkv.pt")
......
@@ -130,18 +130,20 @@ def dtype_tols(dtype: torch.dtype) -> Dict[str, float]:
def assert_allclose(
-    l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float, rtol: float = None
+    l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float = None, rtol: float = None
) -> bool:
    """Ensures two lists are equal."""
    assert len(l1) == len(l2), "Unequal number of outputs."
    for i, (t1, t2) in enumerate(zip(l1, l2)):
-        tols = dict(atol=atol)
+        tols = dtype_tols(t2.dtype)
        if rtol is not None:
            tols["rtol"] = rtol
+        if atol is not None:
+            tols["atol"] = atol
        result = torch.allclose(t1, t2, **tols)
        if not result:
            diff = torch.abs(t1 - t2)
-            tol = atol + (rtol * torch.abs(t2))
+            tol = tols["atol"] + (tols["rtol"] * torch.abs(t2))
            exceed_mask = diff > tol
            if exceed_mask.any():
                indices = torch.nonzero(exceed_mask, as_tuple=True)
......
@@ -66,6 +66,9 @@ list(APPEND transformer_engine_SOURCES
     transpose/quantize_transpose_square_blockwise.cu
     transpose/quantize_transpose_vector_blockwise.cu
     activation/gelu.cu
+    fused_attn/flash_attn.cu
+    fused_attn/context_parallel.cu
+    fused_attn/kv_cache.cu
     fused_attn/fused_attn_f16_max512_seqlen.cu
     fused_attn/fused_attn_f16_arbitrary_seqlen.cu
     activation/relu.cu
@@ -173,6 +176,9 @@ set_source_files_properties(fused_softmax/scaled_masked_softmax.cu
                            multi_tensor/l2norm.cu
                            multi_tensor/scale.cu
                            multi_tensor/sgd.cu
+                            fused_attn/flash_attn.cu
+                            fused_attn/context_parallel.cu
+                            fused_attn/kv_cache.cu
                            PROPERTIES
                            COMPILE_OPTIONS "--use_fast_math")

option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF)
......
@@ -111,7 +111,7 @@ struct Tensor {
        columnwise_scale_inv(nullptr, {1}, DType::kFloat32),
        scaling_mode(NVTE_DELAYED_TENSOR_SCALING) {}

-  int numel() const {
+  size_t numel() const {
    size_t acc = 1;
    for (const auto dim : shape()) {
      acc *= dim;
@@ -133,6 +133,14 @@ struct Tensor {
    return data.dtype;
  }
size_t dim() const {
if (!has_data() && has_columnwise_data()) {
return columnwise_data.shape.size();
} else {
return data.shape.size();
}
}
  std::vector<size_t> shape() const {
    /* Note: We sometimes experience spurious compiler errors
     * (-Wstringop-overflow) from this function. It appears that GCC
@@ -385,6 +393,33 @@ struct TypeInfo {
      NVTE_ERROR("Invalid type."); \
  }
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(dtype, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
case DType::kFloat32: { \
using type = float; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat16: { \
using type = fp16; \
{ __VA_ARGS__ } \
} break; \
case DType::kBFloat16: { \
using type = bf16; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat8E4M3: { \
using type = fp8e4m3; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat8E5M2: { \
using type = fp8e5m2; \
{ __VA_ARGS__ } \
} break; \
default: \
NVTE_ERROR("Invalid type."); \
}
#define TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(dtype, type, ...) \
  switch (dtype) { \
    using namespace transformer_engine; \
......
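For context, the new TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT macro above follows the same pattern as the existing TYPE_SWITCH macros: it maps a runtime DType to a compile-time type alias and expands the trailing statement once for the matching case. A minimal usage sketch follows; scale_tensor and scale_inplace are hypothetical helpers, not part of this commit.

// Hypothetical example only: dispatch a templated launcher over a Tensor's runtime dtype.
namespace transformer_engine {

template <typename T>
void scale_inplace(T *ptr, size_t n, float s, cudaStream_t stream);  // assumed helper

void scale_tensor(Tensor &t, float s, cudaStream_t stream) {
  TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(
      t.dtype(), dtype,
      scale_inplace<dtype>(reinterpret_cast<dtype *>(t.data.dptr), t.numel(), s, stream););
}

}  // namespace transformer_engine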
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "../common.h"
#include "transformer_engine/fused_attn.h"
namespace transformer_engine {
namespace flash_attention {
constexpr int warp_size = 32;
constexpr int type_size = 2; // FP16 or BF16
constexpr int nvec = sizeof(uint64_t) / type_size;
constexpr int load_size = warp_size * nvec;
constexpr int block_size = 512;
template <typename T>
__launch_bounds__(block_size) __global__
void prepare_kernel_fwd(const T *qkvi, T *qkv, const size_t B, const size_t S, const size_t Z,
const size_t W) {
const int warpid = (blockDim.x * blockIdx.x + threadIdx.x) / warp_size;
const int id_in_warp = threadIdx.x % warp_size;
const size_t offset_input = blockIdx.y * W + warpid * 3 * W * Z + id_in_warp * nvec;
const T *my_input = qkvi + offset_input;
const size_t s = warpid / B;
if (s >= S) return;
const size_t b = warpid % B;
const size_t offset_output = blockIdx.y * B * S * Z * W + (s + b * S) * W * Z + id_in_warp * nvec;
T *my_output = qkv + offset_output;
for (int i = 0; i < Z; ++i) {
uint64_t *out = reinterpret_cast<uint64_t *>(my_output + i * load_size);
*out = *reinterpret_cast<const uint64_t *>(my_input + i * load_size * 3);
}
}
template <typename T>
__launch_bounds__(block_size) __global__
void prepare_kernel_bwd(const T *q, const T *k, const T *v, T *qkv, const size_t B,
const size_t S, const size_t Z, const size_t W) {
const T *input = blockIdx.y == 0 ? q : (blockIdx.y == 1 ? k : v);
const int warpid = (blockDim.x * blockIdx.x + threadIdx.x) / warp_size;
const int id_in_warp = threadIdx.x % warp_size;
const size_t offset_input = warpid * W * Z + id_in_warp * nvec;
const T *my_input = input + offset_input;
const size_t b = warpid / S;
if (b >= B) return;
const size_t s = warpid % S;
const size_t offset_output = (b + s * B) * 3 * W * Z + id_in_warp * nvec + blockIdx.y * W;
T *my_output = qkv + offset_output;
for (int i = 0; i < Z; ++i) {
uint64_t *out = reinterpret_cast<uint64_t *>(my_output + i * load_size * 3);
*out = *reinterpret_cast<const uint64_t *>(my_input + i * load_size);
}
}
void prepare_flash_attn_fwd(Tensor qkvi, Tensor qkv, cudaStream_t stream) {
using namespace transformer_engine;
NVTE_CHECK(qkvi.dim() == 4, "Expected 4-dim tensor.");
NVTE_CHECK(qkvi.dtype() == DType::kFloat16 || qkvi.dtype() == DType::kBFloat16);
auto qkvi_shape = qkvi.shape();
NVTE_CHECK(qkvi_shape[3] % load_size == 0);
NVTE_CHECK(qkvi_shape[3] == load_size);
// [s, b, n, h * 3] -> [3, b, s, n, h]
std::vector<uint64_t> shape = {3, qkvi_shape[1], qkvi_shape[0], qkvi_shape[2], qkvi_shape[3]};
size_t warps = qkvi_shape[0] * qkvi_shape[1];
size_t warps_per_block = block_size / warp_size;
size_t blocks = (warps + warps_per_block - 1) / warps_per_block;
dim3 grid(blocks, 3);
int threads = block_size;
TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(
qkvi.dtype(), dtype,
prepare_kernel_fwd<dtype><<<grid, threads, 0, stream>>>(
reinterpret_cast<dtype *>(qkvi.data.dptr), reinterpret_cast<dtype *>(qkv.data.dptr),
shape[1], shape[2], shape[3], shape[4]););
}
void prepare_flash_attn_bwd(Tensor q, Tensor k, Tensor v, Tensor qkv, cudaStream_t stream) {
using namespace transformer_engine;
NVTE_CHECK(q.dim() == 4, "Expected 4-dim tensor.");
NVTE_CHECK(k.dim() == 4, "Expected 4-dim tensor.");
NVTE_CHECK(v.dim() == 4, "Expected 4-dim tensor.");
NVTE_CHECK(q.dtype() == DType::kFloat16 || q.dtype() == DType::kBFloat16);
NVTE_CHECK(k.dtype() == q.dtype());
NVTE_CHECK(v.dtype() == q.dtype());
auto q_shape = q.shape();
auto k_shape = k.shape();
auto v_shape = v.shape();
NVTE_CHECK(q_shape[3] % load_size == 0);
NVTE_CHECK(q_shape[3] == load_size);
NVTE_CHECK(k_shape[3] % load_size == 0);
NVTE_CHECK(k_shape[3] == load_size);
NVTE_CHECK(v_shape[3] % load_size == 0);
NVTE_CHECK(v_shape[3] == load_size);
// 3 x [s, b, n, h] -> [b, s, n, 3 * h]
std::vector<uint64_t> shape = {q_shape[1], q_shape[0], q_shape[2], 3 * q_shape[3]};
size_t warps = q_shape[0] * q_shape[1];
size_t warps_per_block = block_size / warp_size;
size_t blocks = (warps + warps_per_block - 1) / warps_per_block;
dim3 grid(blocks, 3);
int threads = block_size;
TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(
q.dtype(), dtype,
prepare_kernel_bwd<dtype><<<grid, threads, 0, stream>>>(
reinterpret_cast<dtype *>(q.data.dptr), reinterpret_cast<dtype *>(k.data.dptr),
reinterpret_cast<dtype *>(v.data.dptr), reinterpret_cast<dtype *>(qkv.data.dptr),
q_shape[0], q_shape[1], q_shape[2], q_shape[3]););
}
} // namespace flash_attention
} // namespace transformer_engine
void nvte_prepare_flash_attn_fwd(NVTETensor qkvi, NVTETensor qkv, cudaStream_t stream) {
NVTE_API_CALL(nvte_prepare_flash_attn_fwd);
using namespace transformer_engine;
flash_attention::prepare_flash_attn_fwd(*reinterpret_cast<Tensor *>(qkvi),
*reinterpret_cast<Tensor *>(qkv), stream);
}
void nvte_prepare_flash_attn_bwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTETensor qkv,
cudaStream_t stream) {
NVTE_API_CALL(nvte_prepare_flash_attn_bwd);
using namespace transformer_engine;
flash_attention::prepare_flash_attn_bwd(
*reinterpret_cast<Tensor *>(q), *reinterpret_cast<Tensor *>(k),
*reinterpret_cast<Tensor *>(v), *reinterpret_cast<Tensor *>(qkv), stream);
}
@@ -3,48 +3,15 @@
 *
 * See LICENSE for license information.
 ************************************************************************/

-#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_KV_CACHE_CUH_
-#define TRANSFORMER_ENGINE_FUSED_ATTN_KV_CACHE_CUH_
-
-namespace transformer_engine {
-namespace fused_attn {
-
-template <typename scalar_t>
-__global__ void convert_thd_to_bshd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens,
-                                           int b, int max_seq_len, int h, int d) {
-  // tensor: thd; new_tensor: bshd
-  // cu_seqlens: [b + 1]
-  for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
-    int num_elts = (cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]) * h * d;
-    int thd_offset = cu_seqlens[batch_idx] * h * d;
-    int bshd_offset = batch_idx * max_seq_len * h * d;
-    scalar_t *thd_token = tensor + thd_offset;
-    scalar_t *bshd_token = new_tensor + bshd_offset;
-    for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
-      *(bshd_token + i) = *(thd_token + i);
-    }
-  }
-}
-
-template <typename scalar_t>
-__global__ void convert_bshd_to_thd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens,
-                                           int b, int max_seq_len, int h, int d) {
-  // tensor: bshd; new_tensor: thd
-  // cu_seqlens: [b + 1]
-  for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
-    int seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx];
-    int num_elts = seqlen * h * d;
-    int bshd_offset = batch_idx * max_seq_len * h * d;
-    int thd_offset = cu_seqlens[batch_idx] * h * d;
-    scalar_t *bshd_token = tensor + bshd_offset;
-    scalar_t *thd_token = new_tensor + thd_offset;
-    for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
-      *(thd_token + i) = *(bshd_token + i);
-    }
-  }
-}
-
-template <typename scalar_t>
-__global__ void reindex_kv_cache_kernel(scalar_t *k_cache, scalar_t *v_cache, int *batch_indices,
+#include "../common.h"
+#include "transformer_engine/fused_attn.h"
+
+namespace transformer_engine {
+namespace kv_cache {
+
+template <typename dtype>
+__global__ void reindex_kv_cache_kernel(dtype *k_cache, dtype *v_cache, int *batch_indices,
                                        int *cu_new_lens, int *cu_cached_lens, int h_kv, int d_k,
                                        int d_v, int b, int max_seq_len) {
  // k_cache, v_cache: bshd
@@ -75,11 +42,11 @@ __global__ void reindex_kv_cache_kernel(scalar_t *k_cache, scalar_t *v_cache, in
  }
}

-template <typename scalar_t>
-__global__ void copy_to_kv_cache_kernel(scalar_t *new_k, scalar_t *new_v, scalar_t *k_cache,
-                                        scalar_t *v_cache, int *page_table, int *cu_new_lens,
-                                        int *cu_cached_lens, NVTE_QKV_Format qkv_format, int h_kv,
-                                        int d_k, int d_v, int b, int max_ctx_len, int max_seq_len,
+template <typename dtype>
+__global__ void copy_to_kv_cache_kernel(dtype *new_k, dtype *new_v, dtype *k_cache, dtype *v_cache,
+                                        int *page_table, int *cu_new_lens, int *cu_cached_lens,
+                                        NVTE_QKV_Format qkv_format, int h_kv, int d_k, int d_v,
+                                        int b, int max_ctx_len, int max_seq_len,
                                        int max_pages_per_seq, bool is_non_paged) {
  // new_k, new_v: qkv_format; k_cache, v_cache: bshd
  // cu_new_lens, cu_cached_lens: [b + 1]
@@ -140,6 +107,191 @@ __global__ void copy_to_kv_cache_kernel(scalar_t *new_k, scalar_t *new_v, scalar
      }
    }
  }
-}  // namespace fused_attn
template <typename dtype>
void copy_to_kv_cache_launcher(Tensor new_k, Tensor new_v, Tensor k_cache, Tensor v_cache,
Tensor page_table, Tensor cu_new_lens, Tensor cu_cached_lens,
NVTE_QKV_Format qkv_format, int h_kv, int d_k, int d_v, int b,
int max_ctx_len, int max_seq_len, int max_pages_per_seq,
bool is_non_paged, cudaStream_t stream) {
if (new_k.has_data() && new_v.has_data() && k_cache.has_data() && v_cache.has_data()) {
if (is_non_paged) {
reindex_kv_cache_kernel<<<16, 256, 0, stream>>>(
reinterpret_cast<dtype *>(k_cache.data.dptr),
reinterpret_cast<dtype *>(v_cache.data.dptr),
reinterpret_cast<int *>(page_table.data.dptr),
reinterpret_cast<int *>(cu_new_lens.data.dptr),
reinterpret_cast<int *>(cu_cached_lens.data.dptr), h_kv, d_k, d_v, b, max_seq_len);
}
copy_to_kv_cache_kernel<<<16, 256, 0, stream>>>(
reinterpret_cast<dtype *>(new_k.data.dptr), reinterpret_cast<dtype *>(new_v.data.dptr),
reinterpret_cast<dtype *>(k_cache.data.dptr), reinterpret_cast<dtype *>(v_cache.data.dptr),
reinterpret_cast<int *>(page_table.data.dptr),
reinterpret_cast<int *>(cu_new_lens.data.dptr),
reinterpret_cast<int *>(cu_cached_lens.data.dptr), qkv_format, h_kv, d_k, d_v, b,
max_ctx_len, max_seq_len, max_pages_per_seq, is_non_paged);
}
}
void copy_to_kv_cache(Tensor new_k, Tensor new_v, Tensor k_cache, Tensor v_cache, Tensor page_table,
Tensor cu_new_lens, Tensor cu_cached_lens, NVTE_QKV_Format qkv_format, int b,
int max_ctx_len, int max_seq_len, int max_pages_per_seq, bool is_non_paged,
cudaStream_t stream) {
int h_kv = new_k.shape()[new_k.dim() - 2];
int d_k = new_k.shape()[new_k.dim() - 1];
int d_v = new_v.shape()[new_v.dim() - 1];
NVTE_CHECK(k_cache.dtype() == v_cache.dtype() && new_k.dtype() == new_v.dtype() &&
new_k.dtype() == k_cache.dtype(),
"new_k, new_v, k_cache and v_cache must be of the same data type.");
NVTE_CHECK(qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD ||
qkv_format == NVTE_QKV_Format::NVTE_THD,
"qkv_format must be {BSHD, SBHD, THD}.");
TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(
k_cache.dtype(), dtype,
copy_to_kv_cache_launcher<dtype>(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens,
cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len,
max_seq_len, max_pages_per_seq, is_non_paged, stream););
}
template <typename scalar_t>
__global__ void convert_thd_to_bshd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens,
int b, int max_seq_len, int h, int d) {
// tensor: thd; new_tensor: bshd
// cu_seqlens: [b + 1]
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
int num_elts = (cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]) * h * d;
int thd_offset = cu_seqlens[batch_idx] * h * d;
int bshd_offset = batch_idx * max_seq_len * h * d;
scalar_t *thd_token = tensor + thd_offset;
scalar_t *bshd_token = new_tensor + bshd_offset;
for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
*(bshd_token + i) = *(thd_token + i);
}
}
}
template <typename scalar_t>
void convert_thd_to_bshd_launcher(Tensor tensor, Tensor new_tensor, Tensor cu_seqlens, int b,
int max_seq_len, int h, int d, cudaStream_t stream) {
using namespace transformer_engine;
convert_thd_to_bshd_kernel<<<16, 256, 0, stream>>>(
reinterpret_cast<scalar_t *>(tensor.data.dptr),
reinterpret_cast<scalar_t *>(new_tensor.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), b, max_seq_len, h, d);
}
void convert_thd_to_bshd(Tensor tensor, Tensor cu_seqlens, Tensor new_tensor, int b,
int max_seq_len, cudaStream_t stream) {
using namespace transformer_engine;
auto tensor_shape = tensor.shape();
TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(
new_tensor.dtype(), dtype,
convert_thd_to_bshd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len,
tensor_shape[1], tensor_shape[2], stream););
}
template <typename scalar_t>
__global__ void convert_bshd_to_thd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens,
int b, int max_seq_len, int h, int d) {
// tensor: bshd; new_tensor: thd
// cu_seqlens: [b + 1]
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
int seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx];
int num_elts = seqlen * h * d;
int bshd_offset = batch_idx * max_seq_len * h * d;
int thd_offset = cu_seqlens[batch_idx] * h * d;
scalar_t *bshd_token = tensor + bshd_offset;
scalar_t *thd_token = new_tensor + thd_offset;
for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
*(thd_token + i) = *(bshd_token + i);
}
}
}
template <typename scalar_t>
void convert_bshd_to_thd_launcher(Tensor tensor, Tensor new_tensor, Tensor cu_seqlens, int b,
int max_seq_len, int h, int d, cudaStream_t stream) {
using namespace transformer_engine;
convert_bshd_to_thd_kernel<<<16, 256, 0, stream>>>(
reinterpret_cast<scalar_t *>(tensor.data.dptr),
reinterpret_cast<scalar_t *>(new_tensor.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), b, max_seq_len, h, d);
}
void convert_bshd_to_thd(Tensor tensor, Tensor cu_seqlens, Tensor new_tensor, int t,
cudaStream_t stream) {
using namespace transformer_engine;
auto tensor_shape = tensor.shape();
TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(
tensor.dtype(), dtype,
convert_bshd_to_thd_launcher<dtype>(tensor, new_tensor, cu_seqlens, tensor_shape[0],
tensor_shape[1], tensor_shape[2], tensor_shape[3],
stream););
}
} // namespace kv_cache
}  // namespace transformer_engine
-#endif
/***************************************************************************************************
* KV Cache: Copy new KV tokens to the KV cache
* 1. new_k and new_v are in qkv_format; k_cache and v_cache are in 'bshd' format
* 2. cu_new_lens and cu_cached_lens are in shape [b + 1]; cu_cached_lens include the added lens
*    in the current step
* 3. Non-paged KV cache is a special case of paged KV cache, with page_table = [b, 1] and
*    max_pages_per_seq = 1. We use the same underlying kernel for both non-paged and paged.
*    Set is_non_paged = True/False to indicate which mode is used.
* 4. is_non_paged = True also re-indexes the KV cache, e.g. the initial batch indices [0, 3, 1, 2]
*    become [0, 1, 1, 2]. The page_table = batch_indices.unsqueeze(1) is however unchanged.
*    batch_indices_post can be used for monotonic indexing, i.e. [0, 1, 2, 3]. batch_indices is
*    preserved for the next layer in the same iteration.
* 5. Only supports the same page_table for k_cache and v_cache
* 6. Only supports pad_between_seqs = False when qkv_format = thd, i.e. there should be no pad
*    tokens between sequences in new_k and new_v such as [a a a 0..0 b b 0..0 c 0..0].
**************************************************************************************************/
void nvte_copy_to_kv_cache(NVTETensor new_k, NVTETensor new_v, NVTETensor k_cache,
NVTETensor v_cache, NVTETensor page_table, NVTETensor cu_new_lens,
NVTETensor cu_cached_lens, NVTE_QKV_Format qkv_format, int b,
int max_ctx_len, int max_seq_len, int max_pages_per_seq,
int is_non_paged, cudaStream_t stream) {
NVTE_API_CALL(nvte_copy_to_kv_cache);
using namespace transformer_engine;
kv_cache::copy_to_kv_cache(
*reinterpret_cast<Tensor *>(new_k), *reinterpret_cast<Tensor *>(new_v),
*reinterpret_cast<Tensor *>(k_cache), *reinterpret_cast<Tensor *>(v_cache),
*reinterpret_cast<Tensor *>(page_table), *reinterpret_cast<Tensor *>(cu_new_lens),
*reinterpret_cast<Tensor *>(cu_cached_lens), qkv_format, b, max_ctx_len, max_seq_len,
max_pages_per_seq, is_non_paged, stream);
}
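As an illustration of the paged addressing described in points 3 and 4 of the comment above, the sketch below maps a token to a physical cache slot. It is not code from this commit: the helper name is invented and it assumes page_size = max_seq_len / max_pages_per_seq and a [num_pages, page_size, h_kv, d_k] cache layout.

// Hypothetical sketch: map (batch_idx, token_idx) to a flat offset in the KV cache.
// With is_non_paged = true, page_table is just batch_indices.unsqueeze(1) and
// max_pages_per_seq == 1, so the same arithmetic covers both modes.
inline size_t kv_cache_dst_offset(const int *page_table, int batch_idx, int token_idx,
                                  int max_pages_per_seq, int page_size, int h_kv, int d_k) {
  const int logical_page = token_idx / page_size;   // which page of this sequence
  const int slot = token_idx % page_size;           // position inside that page
  const int physical_page = page_table[batch_idx * max_pages_per_seq + logical_page];
  return (static_cast<size_t>(physical_page) * page_size + slot) *
         static_cast<size_t>(h_kv) * d_k;
}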
/***************************************************************************************************
* KV Cache: Convert a tensor from qkv_format = thd to qkv_format = bshd
**************************************************************************************************/
void nvte_convert_thd_to_bshd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor,
int b, int max_seq_len, cudaStream_t stream) {
NVTE_API_CALL(nvte_convert_thd_to_bshd);
using namespace transformer_engine;
kv_cache::convert_thd_to_bshd(*reinterpret_cast<Tensor *>(tensor),
*reinterpret_cast<Tensor *>(cu_seqlens),
*reinterpret_cast<Tensor *>(new_tensor), b, max_seq_len, stream);
}
/***************************************************************************************************
* KV Cache: Convert a tensor from qkv_format = bshd to qkv_format = thd
**************************************************************************************************/
void nvte_convert_bshd_to_thd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor,
int t, cudaStream_t stream) {
NVTE_API_CALL(nvte_convert_bshd_to_thd);
using namespace transformer_engine;
kv_cache::convert_bshd_to_thd(*reinterpret_cast<Tensor *>(tensor),
*reinterpret_cast<Tensor *>(cu_seqlens),
*reinterpret_cast<Tensor *>(new_tensor), t, stream);
}
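To make the indexing in convert_thd_to_bshd above explicit, here is a hypothetical CPU reference of the same copy (illustration only, not part of this commit): sequence i occupies tokens cu_seqlens[i]..cu_seqlens[i+1] of the packed thd tensor and is copied to the front of row i of the padded bshd tensor.

// Hypothetical CPU reference for the thd -> bshd conversion; the real copy runs on the GPU above.
template <typename T>
void thd_to_bshd_reference(const T *thd, T *bshd, const int *cu_seqlens, int b, int max_seq_len,
                           int h, int d) {
  const size_t token = static_cast<size_t>(h) * d;  // elements per token
  for (int i = 0; i < b; ++i) {
    const int seqlen = cu_seqlens[i + 1] - cu_seqlens[i];
    const T *src = thd + static_cast<size_t>(cu_seqlens[i]) * token;
    T *dst = bshd + static_cast<size_t>(i) * max_seq_len * token;
    for (size_t e = 0; e < static_cast<size_t>(seqlen) * token; ++e) dst[e] = src[e];
    // the remaining (max_seq_len - seqlen) tokens of row i stay as padding
  }
}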
@@ -610,5 +610,27 @@ uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cud
  return hout;
}
__global__ void extract_seed_and_offset(int64_t *rng_state_ptr, bool captured, int64_t *seed_ptr,
uint64_t seed_val, int64_t *offset_ptr, uint64_t offset_val,
uint32_t offset_intragraph) {
if (captured) {
rng_state_ptr[0] = *seed_ptr;
rng_state_ptr[1] = static_cast<int64_t>(*offset_ptr + static_cast<int64_t>(offset_intragraph));
} else {
rng_state_ptr[0] = static_cast<int64_t>(seed_val);
rng_state_ptr[1] = static_cast<int64_t>(offset_val);
}
}
}  // namespace fused_attn
}  // namespace transformer_engine
void nvte_extract_seed_and_offset(int64_t *rng_state_ptr, int captured, int64_t *seed_ptr,
uint64_t seed_val, int64_t *offset_ptr, uint64_t offset_val,
uint32_t offset_intragraph, cudaStream_t stream) {
NVTE_API_CALL(nvte_extract_seed_and_offset);
using namespace transformer_engine;
fused_attn::extract_seed_and_offset<<<1, 1, 0, stream>>>(
rng_state_ptr, captured, seed_ptr, seed_val, offset_ptr, offset_val, offset_intragraph);
}
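The new nvte_extract_seed_and_offset entry point fills a 2-element int64 rng_state buffer ({seed, offset}) that the fused-attention backends consume as the Philox RNG state. A hypothetical host-side call for the non-captured path (the buffer handling and values are illustrative only):

// Hypothetical usage sketch; captured == 0 means seed_ptr/offset_ptr are ignored and the
// immediate seed_val/offset_val are written into rng_state instead.
#include <cuda_runtime.h>
#include <cstdint>

void fill_rng_state_example(cudaStream_t stream) {
  int64_t *rng_state = nullptr;
  cudaMalloc(&rng_state, 2 * sizeof(int64_t));
  nvte_extract_seed_and_offset(rng_state, /*captured=*/0,
                               /*seed_ptr=*/nullptr, /*seed_val=*/1234ULL,
                               /*offset_ptr=*/nullptr, /*offset_val=*/0ULL,
                               /*offset_intragraph=*/0U, stream);
  cudaStreamSynchronize(stream);  // ensure the single-thread kernel finished before freeing
  cudaFree(rng_state);
}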
@@ -244,7 +244,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
 *  \param[in]  stream  CUDA stream used for this operation.
 */
void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S,
-                                   NVTETensor O, NVTETensorPack* Aux_CTX_Tensors,
+                                   NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
                                   const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
                                   const NVTETensor rng_state, size_t max_seqlen, bool is_training,
                                   float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
@@ -300,7 +300,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
 */
void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO,
                                   const NVTETensor S, NVTETensor dP,
-                                   const NVTETensorPack* Aux_CTX_Tensors, NVTETensor dQKV,
+                                   const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV,
                                   NVTETensor dBias, const NVTETensor cu_seqlens,
                                   const NVTETensor cu_seqlens_padded, size_t max_seqlen,
                                   float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
@@ -368,7 +368,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
 */
void nvte_fused_attn_fwd_kvpacked(
    const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O,
-    NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
+    NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
    const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
    const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
    size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
@@ -429,7 +429,7 @@ void nvte_fused_attn_fwd_kvpacked(
 */
void nvte_fused_attn_bwd_kvpacked(
    const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO,
-    const NVTETensor S, NVTETensor dP, const NVTETensorPack* Aux_CTX_Tensors, NVTETensor dQ,
+    const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ,
    NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
    const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
    size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout,
@@ -500,7 +500,7 @@ void nvte_fused_attn_bwd_kvpacked(
 */
void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
                         const NVTETensor Bias, NVTETensor S, NVTETensor O,
-                         NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
+                         NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
                         const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
                         const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
                         const NVTETensor page_table_v, const NVTETensor rng_state,
@@ -569,7 +569,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
 */
void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
                         const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP,
-                         const NVTETensorPack* Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK,
+                         const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK,
                         NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q,
                         const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
                         const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q,
@@ -604,6 +604,51 @@ void nvte_populate_rng_state_async(NVTETensor rng_state_dst, const NVTETensor se
uint32_t nvte_get_runtime_num_segments(NVTETensor cu_seqlen, NVTETensor workspace, size_t len,
                                       cudaStream_t stream);
void nvte_extract_seed_and_offset(int64_t *rng_state_ptr, int captured, int64_t *seed_ptr,
uint64_t seed_val, int64_t *offset_ptr, uint64_t offset_val,
uint32_t offset_intragraph, cudaStream_t stream);
void nvte_copy_to_kv_cache(NVTETensor new_k, NVTETensor new_v, NVTETensor k_cache,
NVTETensor v_cache, NVTETensor page_table, NVTETensor cu_new_lens,
NVTETensor cu_cached_lens, NVTE_QKV_Format qkv_format, int b,
int max_ctx_len, int max_seq_len, int max_pages_per_seq,
int is_non_paged, cudaStream_t stream);
void nvte_cp_thd_read_half_tensor(const NVTETensor &tensor, const NVTETensor &cu_seqlens,
NVTETensor half, int half_idx, cudaStream_t stream);
void nvte_cp_thd_second_half_lse_correction(NVTETensor lse, const NVTETensor &lse_per_step,
const NVTETensor &cu_seqlens, int lse_packed,
cudaStream_t stream);
void nvte_cp_thd_read_second_half_lse(const NVTETensor &lse, const NVTETensor &cu_seqlens,
NVTETensor half_lse, int lse_packed,
int second_half_lse_seqlen, cudaStream_t stream);
void nvte_cp_thd_out_correction(NVTETensor out, const NVTETensor &out_per_step,
const NVTETensor &lse, const NVTETensor &lse_per_step,
const NVTETensor &cu_seqlens, int only_second_half, int lse_packed,
cudaStream_t stream);
void nvte_cp_thd_grad_correction(NVTETensor grad, const NVTETensor &grad_per_step,
const NVTETensor &cu_seqlens, const char *first_half,
const char *second_half, cudaStream_t stream);
void nvte_cp_thd_get_partitioned_indices(const NVTETensor &cu_seqlens, NVTETensor output,
int total_tokens, int world_size, int rank,
cudaStream_t stream);
void nvte_convert_thd_to_bshd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor,
int b, int max_seq_len, cudaStream_t stream);
void nvte_convert_bshd_to_thd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor,
int t, cudaStream_t stream);
void nvte_prepare_flash_attn_fwd(NVTETensor qkvi, NVTETensor qkv, cudaStream_t stream);
void nvte_prepare_flash_attn_bwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTETensor qkv,
cudaStream_t stream);
#ifdef __cplusplus
}  // extern "C"
#endif
......
@@ -218,6 +218,26 @@ std::vector<size_t> getTensorShape(at::Tensor t);
transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
                                                      const std::string& fp8_recipe);
inline size_t typeToSize(transformer_engine::DType t) {
switch (t) {
case transformer_engine::DType::kInt64:
return 8;
case transformer_engine::DType::kInt32:
case transformer_engine::DType::kFloat32:
return 4;
case transformer_engine::DType::kInt16:
case transformer_engine::DType::kFloat16:
case transformer_engine::DType::kBFloat16:
return 2;
case transformer_engine::DType::kByte:
case transformer_engine::DType::kFloat8E4M3:
case transformer_engine::DType::kFloat8E5M2:
return 1;
default:
NVTE_ERROR("Invalid type");
}
}
inline at::ScalarType GetATenDType(transformer_engine::DType t) {
  switch (t) {
    case transformer_engine::DType::kInt16:
......
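For reference, typeToSize above returns the element size in bytes for a TE dtype; a hypothetical helper using it to size a raw buffer (bytes_needed is not part of this commit):

// Hypothetical illustration only; assumes <vector> is available in this header.
inline size_t bytes_needed(const std::vector<size_t> &shape, transformer_engine::DType dtype) {
  size_t numel = 1;
  for (const size_t dim : shape) numel *= dim;
  return numel * typeToSize(dtype);  // e.g. kBFloat16 -> 2 bytes per element
}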
@@ -72,10 +72,10 @@ at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v);
at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len);
at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t);

-void copy_to_kv_cache(torch::Tensor new_k, torch::Tensor new_v, torch::Tensor k_cache,
-                      torch::Tensor v_cache, torch::Tensor page_table, torch::Tensor cu_new_lens,
-                      torch::Tensor cu_cached_lens, NVTE_QKV_Format kv_format, int b,
-                      int max_ctx_len, int max_seq_len, int max_pages_per_seq, bool is_non_paged);
+void copy_to_kv_cache(at::Tensor new_k, at::Tensor new_v, at::Tensor k_cache, at::Tensor v_cache,
+                      at::Tensor page_table, at::Tensor cu_new_lens, at::Tensor cu_cached_lens,
+                      NVTE_QKV_Format kv_format, int b, int max_ctx_len, int max_seq_len,
+                      int max_pages_per_seq, bool is_non_paged);

/***************************************************************************************************
 * GEMM
@@ -392,12 +392,11 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
namespace nvshmem_api {
void init_nvshmem_backend(c10d::ProcessGroup *process_group);

-torch::Tensor create_nvshmem_tensor(const std::vector<int64_t> &shape, c10::ScalarType dtype);
+at::Tensor create_nvshmem_tensor(const std::vector<int64_t> &shape, c10::ScalarType dtype);

-void nvshmem_send_on_current_stream(torch::Tensor src, torch::Tensor dst, int peer,
-                                    torch::Tensor signal);
+void nvshmem_send_on_current_stream(at::Tensor src, at::Tensor dst, int peer, at::Tensor signal);

-void nvshmem_wait_on_current_stream(torch::Tensor signal, const std::string &wait_kind);
+void nvshmem_wait_on_current_stream(at::Tensor signal, const std::string &wait_kind);

void nvshmem_finalize();
}  // namespace nvshmem_api
......