Unverified commit 51cd4415, authored by Kirthi Shankar Sivamani and committed by GitHub

[C][PyTorch]Make pytorch extensions pure cpp (#1754)



* First pass refactor
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* first pass
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* core compiles
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Include cuda dirs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Compiles
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix test
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Move grad outside autocast
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix kv cache
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Address review comments
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Change src file name in cmake
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* move the kernels too
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Move comment
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Move comments around
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* more movement
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* move
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent dfe1a65a
......@@ -8,11 +8,7 @@ from pathlib import Path
import setuptools
from .utils import (
all_files_in_dir,
cuda_archs,
cuda_version,
)
from .utils import all_files_in_dir, cuda_version, get_cuda_include_dirs
def setup_pytorch_extension(
......@@ -30,55 +26,30 @@ def setup_pytorch_extension(
] + all_files_in_dir(extensions_dir)
# Header files
include_dirs = [
common_header_files,
common_header_files / "common",
common_header_files / "common" / "include",
csrc_header_files,
]
include_dirs = get_cuda_include_dirs()
include_dirs.extend(
[
common_header_files,
common_header_files / "common",
common_header_files / "common" / "include",
csrc_header_files,
]
)
# Compiler flags
cxx_flags = [
"-O3",
"-fvisibility=hidden",
]
nvcc_flags = [
"-O3",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_BFLOAT16_OPERATORS__",
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
"-U__CUDA_NO_BFLOAT162_OPERATORS__",
"-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--use_fast_math",
]
cuda_architectures = cuda_archs()
if "70" in cuda_architectures:
nvcc_flags.extend(["-gencode", "arch=compute_70,code=sm_70"])
# Version-dependent CUDA options
try:
version = cuda_version()
except FileNotFoundError:
print("Could not determine CUDA Toolkit version")
print("Could not determine CUDA version")
else:
if version < (12, 0):
raise RuntimeError("Transformer Engine requires CUDA 12.0 or newer")
nvcc_flags.extend(
(
"--threads",
os.getenv("NVTE_BUILD_THREADS_PER_JOB", "1"),
)
)
for arch in cuda_architectures.split(";"):
if arch == "70":
continue # Already handled
nvcc_flags.extend(["-gencode", f"arch=compute_{arch},code=sm_{arch}"])
if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
assert (
......@@ -87,7 +58,6 @@ def setup_pytorch_extension(
mpi_path = Path(os.getenv("MPI_HOME"))
include_dirs.append(mpi_path / "include")
cxx_flags.append("-DNVTE_UB_WITH_MPI")
nvcc_flags.append("-DNVTE_UB_WITH_MPI")
library_dirs = []
libraries = []
......@@ -100,21 +70,17 @@ def setup_pytorch_extension(
library_dirs.append(nvshmem_home / "lib")
libraries.append("nvshmem_host")
cxx_flags.append("-DNVTE_ENABLE_NVSHMEM")
nvcc_flags.append("-DNVTE_ENABLE_NVSHMEM")
# Construct PyTorch CUDA extension
sources = [str(path) for path in sources]
include_dirs = [str(path) for path in include_dirs]
from torch.utils.cpp_extension import CUDAExtension
from torch.utils.cpp_extension import CppExtension
return CUDAExtension(
return CppExtension(
name="transformer_engine_torch",
sources=[str(src) for src in sources],
include_dirs=[str(inc) for inc in include_dirs],
extra_compile_args={
"cxx": cxx_flags,
"nvcc": nvcc_flags,
},
extra_compile_args={"cxx": cxx_flags},
libraries=[str(lib) for lib in libraries],
library_dirs=[str(lib_dir) for lib_dir in library_dirs],
)
......@@ -1628,8 +1628,8 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoP
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
)
if is_training:
out.backward(out_grad)
if is_training:
out.backward(out_grad)
param_names = []
param_names.append("hidden_states.grad")
......@@ -1879,8 +1879,8 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training):
checkpoint_core_attention=False,
core_attention_bias_type=config.attn_bias_type,
)
if is_training:
out.backward(out_grad)
if is_training:
out.backward(out_grad)
if is_training:
return out, (inp[0].grad, inp[1].grad, inp[2].grad)
......@@ -1993,7 +1993,7 @@ def _run_custom_mha_fp8(dtype, config, backend):
mha = Custom_MHA_FP8(config).to(dtype=dtype, device="cuda")
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
out = mha(inp, cu_seqlens, config.max_seqlen_q)
out.backward(out_grad)
out.backward(out_grad)
out = torch.load("out.pt")
dqkv = torch.load("dqkv.pt")
......
......@@ -130,18 +130,20 @@ def dtype_tols(dtype: torch.dtype) -> Dict[str, float]:
def assert_allclose(
l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float, rtol: float = None
l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float = None, rtol: float = None
) -> bool:
"""Ensures two lists are equal."""
assert len(l1) == len(l2), "Unequal number of outputs."
for i, (t1, t2) in enumerate(zip(l1, l2)):
tols = dict(atol=atol)
tols = dtype_tols(t2.dtype)
if rtol is not None:
tols["rtol"] = rtol
if atol is not None:
tols["atol"] = atol
result = torch.allclose(t1, t2, **tols)
if not result:
diff = torch.abs(t1 - t2)
tol = atol + (rtol * torch.abs(t2))
tol = tols["atol"] + (tols["rtol"] * torch.abs(t2))
exceed_mask = diff > tol
if exceed_mask.any():
indices = torch.nonzero(exceed_mask, as_tuple=True)
......
......@@ -66,6 +66,9 @@ list(APPEND transformer_engine_SOURCES
transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise.cu
activation/gelu.cu
fused_attn/flash_attn.cu
fused_attn/context_parallel.cu
fused_attn/kv_cache.cu
fused_attn/fused_attn_f16_max512_seqlen.cu
fused_attn/fused_attn_f16_arbitrary_seqlen.cu
activation/relu.cu
......@@ -173,6 +176,9 @@ set_source_files_properties(fused_softmax/scaled_masked_softmax.cu
multi_tensor/l2norm.cu
multi_tensor/scale.cu
multi_tensor/sgd.cu
fused_attn/flash_attn.cu
fused_attn/context_parallel.cu
fused_attn/kv_cache.cu
PROPERTIES
COMPILE_OPTIONS "--use_fast_math")
option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF)
......
......@@ -111,7 +111,7 @@ struct Tensor {
columnwise_scale_inv(nullptr, {1}, DType::kFloat32),
scaling_mode(NVTE_DELAYED_TENSOR_SCALING) {}
int numel() const {
size_t numel() const {
size_t acc = 1;
for (const auto dim : shape()) {
acc *= dim;
......@@ -133,6 +133,14 @@ struct Tensor {
return data.dtype;
}
size_t dim() const {
if (!has_data() && has_columnwise_data()) {
return columnwise_data.shape.size();
} else {
return data.shape.size();
}
}
std::vector<size_t> shape() const {
/* Note: We sometimes experience spurious compiler errors
* (-Wstringop-overflow) from this function. It appears that GCC
......@@ -385,6 +393,33 @@ struct TypeInfo {
NVTE_ERROR("Invalid type."); \
}
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(dtype, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
case DType::kFloat32: { \
using type = float; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat16: { \
using type = fp16; \
{ __VA_ARGS__ } \
} break; \
case DType::kBFloat16: { \
using type = bf16; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat8E4M3: { \
using type = fp8e4m3; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat8E5M2: { \
using type = fp8e5m2; \
{ __VA_ARGS__ } \
} break; \
default: \
NVTE_ERROR("Invalid type."); \
}
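As a hedged illustration of how this dispatch macro is meant to be called (the helper and function names below are invented for the example and are not part of this diff; it assumes the common transformer_engine header being edited above is included), a caller switches on a runtime DType and instantiates a template with the matching C++ type:

    // Hypothetical usage sketch: dispatch a runtime DType to a templated
    // element-wise copy. Names are illustrative only.
    #include <cstddef>

    template <typename T>
    void copy_elements(const void *src, void *dst, std::size_t n) {
      const T *s = static_cast<const T *>(src);
      T *d = static_cast<T *>(dst);
      for (std::size_t i = 0; i < n; ++i) d[i] = s[i];
    }

    void copy_any_float(transformer_engine::DType dtype, const void *src, void *dst,
                        std::size_t n) {
      TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(dtype, type, copy_elements<type>(src, dst, n););
    }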
#define TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(dtype, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
......
......@@ -3,13 +3,17 @@
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_THD_UTILS_CUH_
#define TRANSFORMER_ENGINE_FUSED_ATTN_THD_UTILS_CUH_
#include <assert.h>
#include <cuda.h>
#include <cuda_bf16.h>
#include "../common.h"
#include "transformer_engine/fused_attn.h"
namespace transformer_engine {
namespace context_parallel {
struct LseCorrectionFunctor {
__forceinline__ __device__ static void run(float *lse, float *half_lse, size_t idx,
size_t half_idx) {
......@@ -49,16 +53,13 @@ struct AddFunctor {
#pragma unroll
for (int i = 0; i < sizeof(float4) / sizeof(dtype); i++) {
p_[i] += p[i];
p_[i] = p_[i] + p[i];
}
reinterpret_cast<float4 *>(token)[idx] = d_;
}
};
namespace transformer_engine {
namespace fused_attn {
/***************************************************************************************************
* Support THD format for Context Parallel: Binary search an array for a target value
**************************************************************************************************/
......@@ -107,6 +108,7 @@ __global__ void thd_partition_indices_kernel(int *output, int *cu_seqlens, int b
/***************************************************************************************************
* Support THD format for Context Parallel: Read the half of a THD tensor
**************************************************************************************************/
__global__ void thd_read_half_tensor_kernel(void *half, void *tensor, int *cu_seqlens, int batch,
int hidden_size_in_bytes, int half_idx,
int dim_size_of_token) {
......@@ -232,7 +234,10 @@ __global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float
dtype *p_per_step = reinterpret_cast<dtype *>(&data_per_step);
dtype *p = reinterpret_cast<dtype *>(&data);
for (int k = 0; k < sizeof(float4) / sizeof(dtype); k++) {
p[k] += (p_per_step[k] == 0 ? 0 : p_per_step[k] * lse_corrected_exp);
p[k] = p[k] +
(p_per_step[k] == static_cast<dtype>(0.f)
? static_cast<dtype>(0.f)
: static_cast<dtype>(static_cast<float>(p_per_step[k]) * lse_corrected_exp));
}
reinterpret_cast<float4 *>(cur_out)[j] = data;
}
......@@ -297,6 +302,442 @@ __global__ void thd_grad_correction_kernel(dtype *grad, dtype *grad_per_step, in
}
}
} // namespace fused_attn
/***************************************************************************************************
* Support THD format for Context Parallel: Read the half of a THD tensor
**************************************************************************************************/
void thd_read_half_tensor(const Tensor &tensor, const Tensor &cu_seqlens, Tensor &half,
int half_idx, cudaStream_t stream) {
using namespace transformer_engine;
NVTE_CHECK(tensor.dim() == 3 || tensor.dim() == 4);
NVTE_CHECK(cu_seqlens.dtype() == DType::kInt32);
auto cu_seqlens_shape = cu_seqlens.shape();
auto tensor_shape = tensor.shape();
NVTE_CHECK(cu_seqlens.dim() == 1);
NVTE_CHECK(cu_seqlens_shape[0] >= 2);
// Shapes of q and dq are [t, h, d], so the dimension of "t" is 0
// Shapes of kv and dkv are [2, t, h, d], so the dimension of "t" is 1
int seq_dim = tensor.dim() == 3 ? 0 : 1;
int batch = cu_seqlens_shape[0] - 1;
int num_heads = tensor_shape[seq_dim + 1];
int dim_per_head = tensor_shape[seq_dim + 2];
int hidden_size_in_bytes = num_heads * dim_per_head * typeToSize(tensor.dtype());
// For 128-bits load/store
NVTE_CHECK(hidden_size_in_bytes % 16 == 0);
// Launch Kernel
constexpr unsigned int block = 256;
unsigned int grid_x = (tensor_shape[seq_dim] / 2 * 32 + block - 1) / block;
unsigned int grid_y = 1;
for (int i = 0; i < seq_dim; i++) {
grid_y *= tensor_shape[i];
}
dim3 grid = {grid_x, grid_y};
thd_read_half_tensor_kernel<<<grid, block, sizeof(int) * (batch + 1), stream>>>(
half.data.dptr, tensor.data.dptr, reinterpret_cast<int *>(cu_seqlens.data.dptr), batch,
hidden_size_in_bytes, half_idx, tensor_shape[seq_dim]);
}
/***************************************************************************************************
* Support THD format for Context Parallel: softmax_lse related operations
**************************************************************************************************/
void thd_second_half_lse_correction(Tensor lse, const Tensor &lse_per_step,
const Tensor &cu_seqlens, bool lse_packed,
cudaStream_t stream) {
using namespace transformer_engine;
NVTE_CHECK(lse.dtype() == DType::kFloat32);
NVTE_CHECK(lse_per_step.dtype() == DType::kFloat32);
NVTE_CHECK(cu_seqlens.dtype() == DType::kInt32);
NVTE_CHECK(cu_seqlens.dim() == 1);
int batch, num_heads, lse_seqlen, second_half_lse_seqlen;
auto cu_seqlens_shape = cu_seqlens.shape();
auto lse_shape = lse.shape();
auto lse_per_step_shape = lse_per_step.shape();
if (lse_packed) {
NVTE_CHECK(lse.dim() == 2);
NVTE_CHECK(lse_per_step.dim() == 2);
batch = cu_seqlens_shape[0] - 1;
num_heads = lse_shape[0];
lse_seqlen = lse_shape[1];
second_half_lse_seqlen = lse_per_step_shape[1];
NVTE_CHECK(lse_per_step_shape[0] == num_heads);
NVTE_CHECK(second_half_lse_seqlen >= lse_seqlen / 2);
} else {
NVTE_CHECK(lse.dim() == 3);
NVTE_CHECK(lse_per_step.dim() == 3);
batch = lse_shape[0];
num_heads = lse_shape[1];
lse_seqlen = lse_shape[2];
second_half_lse_seqlen = lse_per_step_shape[2];
NVTE_CHECK(lse_per_step_shape[0] == batch);
NVTE_CHECK(lse_per_step_shape[1] == num_heads);
NVTE_CHECK(second_half_lse_seqlen == lse_seqlen / 2);
NVTE_CHECK(cu_seqlens_shape[0] == batch + 1);
}
constexpr unsigned int block = 256;
unsigned int grid_x = (lse_seqlen / 2 + block - 1) / block;
unsigned int grid_y = num_heads;
dim3 grid = {grid_x, grid_y};
if (lse_packed) {
thd_lse_kernel<true, LseCorrectionFunctor><<<grid, block, sizeof(int) * (batch + 1), stream>>>(
reinterpret_cast<float *>(lse.data.dptr), reinterpret_cast<float *>(lse_per_step.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen,
second_half_lse_seqlen);
} else {
thd_lse_kernel<false, LseCorrectionFunctor><<<grid, block, sizeof(int) * (batch + 1), stream>>>(
reinterpret_cast<float *>(lse.data.dptr), reinterpret_cast<float *>(lse_per_step.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen,
second_half_lse_seqlen);
}
}
void thd_read_second_half_lse(const Tensor &lse, const Tensor &cu_seqlens, Tensor &half_lse,
bool lse_packed, int second_half_lse_seqlen, cudaStream_t stream) {
using namespace transformer_engine;
NVTE_CHECK(lse.dtype() == DType::kFloat32);
NVTE_CHECK(cu_seqlens.dtype() == DType::kInt32);
NVTE_CHECK(cu_seqlens.dim() == 1);
int batch, num_heads, lse_seqlen;
auto cu_seqlens_shape = cu_seqlens.shape();
auto lse_shape = lse.shape();
if (lse_packed) {
NVTE_CHECK(lse.dim() == 2);
batch = cu_seqlens_shape[0] - 1;
num_heads = lse_shape[0];
lse_seqlen = lse_shape[1];
NVTE_CHECK(second_half_lse_seqlen >= lse_seqlen / 2);
} else {
NVTE_CHECK(lse.dim() == 3);
batch = lse_shape[0];
num_heads = lse_shape[1];
lse_seqlen = lse_shape[2];
NVTE_CHECK(cu_seqlens_shape[0] == batch + 1);
NVTE_CHECK(second_half_lse_seqlen == lse_seqlen / 2);
}
constexpr unsigned int block = 256;
unsigned int grid_x = (lse_seqlen / 2 + block - 1) / block;
unsigned int grid_y = num_heads;
dim3 grid = {grid_x, grid_y};
if (lse_packed) {
thd_lse_kernel<true, ReadLseFunctor><<<grid, block, sizeof(int) * (batch + 1), stream>>>(
reinterpret_cast<float *>(lse.data.dptr), reinterpret_cast<float *>(half_lse.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen,
second_half_lse_seqlen);
} else {
thd_lse_kernel<false, ReadLseFunctor><<<grid, block, sizeof(int) * (batch + 1), stream>>>(
reinterpret_cast<float *>(lse.data.dptr), reinterpret_cast<float *>(half_lse.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen,
second_half_lse_seqlen);
}
}
/***************************************************************************************************
* Support THD format for Context Parallel: Out correction in forward
**************************************************************************************************/
template <typename dtype, int only_second_half>
static void thd_out_correction_helper(Tensor out, const Tensor &out_per_step, const Tensor &lse,
const Tensor &lse_per_step, const Tensor &cu_seqlens,
bool lse_packed, cudaStream_t stream) {
using namespace transformer_engine;
NVTE_CHECK(out.dtype() == out_per_step.dtype());
NVTE_CHECK(lse.dtype() == DType::kFloat32);
NVTE_CHECK(lse_per_step.dtype() == DType::kFloat32);
NVTE_CHECK(cu_seqlens.dtype() == DType::kInt32);
auto out_shape = out.shape();
auto lse_shape = lse.shape();
auto out_per_step_shape = out_per_step.shape();
auto lse_per_step_shape = lse_per_step.shape();
auto cu_seqlens_shape = cu_seqlens.shape();
int total_tokens = out_shape[0];
int num_heads = out_shape[1];
int dim_per_head = out_shape[2];
NVTE_CHECK(out_per_step_shape[0] == total_tokens / (only_second_half + 1));
NVTE_CHECK(out_per_step_shape[1] == num_heads);
NVTE_CHECK(out_per_step_shape[2] == dim_per_head);
int batch, lse_seqlen, lse_per_step_seqlen;
if (lse_packed) {
batch = cu_seqlens_shape[0] - 1;
lse_seqlen = lse_shape[1];
lse_per_step_seqlen = lse_per_step_shape[1];
NVTE_CHECK(lse_shape[0] == num_heads);
NVTE_CHECK(lse_seqlen >= total_tokens);
NVTE_CHECK(lse_per_step_shape[0] == num_heads);
NVTE_CHECK(lse_per_step_seqlen >= lse_seqlen / (only_second_half + 1));
} else {
batch = lse_shape[0];
lse_seqlen = lse_shape[2];
lse_per_step_seqlen = lse_per_step_shape[2];
NVTE_CHECK(lse_shape[1] == num_heads);
NVTE_CHECK(lse_per_step_shape[0] == batch);
NVTE_CHECK(lse_per_step_shape[1] == num_heads);
NVTE_CHECK(lse_per_step_seqlen == lse_seqlen / (only_second_half + 1));
NVTE_CHECK(cu_seqlens_shape[0] == batch + 1);
}
constexpr int tile = 16;
constexpr int block = 512;
unsigned int grid_x =
(static_cast<size_t>(total_tokens) / (only_second_half + 1) * tile + block - 1) / block;
dim3 grid = {grid_x, (unsigned int)num_heads};
if (lse_packed) {
thd_out_correction_kernel<dtype, only_second_half, tile, true>
<<<grid, block, sizeof(int) * (batch + 1), stream>>>(
reinterpret_cast<dtype *>(out.data.dptr),
reinterpret_cast<dtype *>(out_per_step.data.dptr),
reinterpret_cast<float *>(lse.data.dptr),
reinterpret_cast<float *>(lse_per_step.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, dim_per_head,
lse_seqlen, lse_per_step_seqlen);
} else {
thd_out_correction_kernel<dtype, only_second_half, tile, false>
<<<grid, block, sizeof(int) * (batch + 1), stream>>>(
reinterpret_cast<dtype *>(out.data.dptr),
reinterpret_cast<dtype *>(out_per_step.data.dptr),
reinterpret_cast<float *>(lse.data.dptr),
reinterpret_cast<float *>(lse_per_step.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, dim_per_head,
lse_seqlen, lse_per_step_seqlen);
}
}
void thd_out_correction(Tensor out, const Tensor &out_per_step, const Tensor &lse,
const Tensor &lse_per_step, const Tensor &cu_seqlens, bool only_second_half,
bool lse_packed, cudaStream_t stream) {
using namespace transformer_engine;
if (only_second_half) {
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
out.dtype(), dtype,
thd_out_correction_helper<dtype, 1>(out, out_per_step, lse, lse_per_step, cu_seqlens,
lse_packed, stream););
} else {
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
out.dtype(), dtype,
thd_out_correction_helper<dtype, 0>(out, out_per_step, lse, lse_per_step, cu_seqlens,
lse_packed, stream););
}
}
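For background (stated here as context, not quoted from this commit), the correction these kernels perform is the standard log-sum-exp combination of per-step partial attention outputs:

    lse = log( sum_i exp(lse_i) )
    out = sum_i exp(lse_i - lse) * out_i

which corresponds to the lse_corrected_exp factor applied per token and head in thd_out_correction_kernel above.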
/***************************************************************************************************
* Support THD format for Context Parallel: Gradients correction in backward
**************************************************************************************************/
template <typename dtype, typename Functor_0, typename Functor_1, int functor_idx>
static void thd_grad_correction_helper(Tensor grad, const Tensor &grad_per_step,
const Tensor &cu_seqlens, cudaStream_t stream) {
using namespace transformer_engine;
NVTE_CHECK(grad.dim() == 3 || grad.dim() == 4);
NVTE_CHECK(cu_seqlens.dtype() == DType::kInt32);
NVTE_CHECK(cu_seqlens.dim() == 1);
auto grad_shape = grad.shape();
auto cu_seqlens_shape = cu_seqlens.shape();
auto grad_per_step_shape = grad_per_step.shape();
// Shape of dq is [t, h, d], so the dimension of "t" is 0
// Shape of dkv is [2, t, h, d], so the dimension of "t" is 1
int seq_dim = grad.dim() == 3 ? 0 : 1;
int total_tokens = grad_shape[seq_dim];
int num_heads = grad_shape[seq_dim + 1];
int dim_per_head = grad_shape[seq_dim + 2];
int batch = cu_seqlens_shape[0] - 1;
if constexpr (functor_idx < 2) {
NVTE_CHECK(grad_per_step_shape[seq_dim] == total_tokens / 2);
} else {
NVTE_CHECK(grad_per_step_shape[seq_dim] == total_tokens);
}
NVTE_CHECK(grad_per_step_shape[seq_dim + 1] == num_heads);
NVTE_CHECK(grad_per_step_shape[seq_dim + 2] == dim_per_head);
size_t hidden_size = num_heads * dim_per_head;
NVTE_CHECK((hidden_size * typeToSize(grad.dtype())) % 16 == 0);
constexpr unsigned int block = 256;
unsigned int grid_x;
if constexpr (functor_idx < 2) {
grid_x = (total_tokens / 2 * 32 + block - 1) / block;
} else {
grid_x = (total_tokens * 32 + block - 1) / block;
}
unsigned int grid_y = 1;
for (int i = 0; i < seq_dim; i++) {
grid_y *= grad_shape[i];
}
dim3 grid = {grid_x, grid_y};
thd_grad_correction_kernel<dtype, Functor_0, Functor_1, functor_idx, 32>
<<<grid, block, sizeof(int) * (batch + 1), stream>>>(
reinterpret_cast<dtype *>(grad.data.dptr),
reinterpret_cast<dtype *>(grad_per_step.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, hidden_size, total_tokens);
}
template <typename dtype>
static void thd_grad_dispatcher(Tensor grad, const Tensor &grad_per_step, const Tensor &cu_seqlens,
const std::string &first_half, const std::string &second_half,
cudaStream_t stream) {
using namespace transformer_engine;
if (first_half == "add" && second_half == "none") {
thd_grad_correction_helper<dtype, AddFunctor<dtype>, EmptyFunctor, 0>(grad, grad_per_step,
cu_seqlens, stream);
} else if (first_half == "copy" && second_half == "none") {
thd_grad_correction_helper<dtype, CopyFunctor, EmptyFunctor, 0>(grad, grad_per_step, cu_seqlens,
stream);
} else if (first_half == "none" && second_half == "add") {
thd_grad_correction_helper<dtype, EmptyFunctor, AddFunctor<dtype>, 1>(grad, grad_per_step,
cu_seqlens, stream);
} else if (first_half == "none" && second_half == "copy") {
thd_grad_correction_helper<dtype, EmptyFunctor, CopyFunctor, 1>(grad, grad_per_step, cu_seqlens,
stream);
} else if (first_half == "add" && second_half == "copy") {
thd_grad_correction_helper<dtype, AddFunctor<dtype>, CopyFunctor, 2>(grad, grad_per_step,
cu_seqlens, stream);
} else if (first_half == "copy" && second_half == "add") {
thd_grad_correction_helper<dtype, CopyFunctor, AddFunctor<dtype>, 2>(grad, grad_per_step,
cu_seqlens, stream);
} else {
NVTE_ERROR("Unsupported Functor of first half and second_half\n");
}
}
void thd_grad_correction(Tensor grad, const Tensor &grad_per_step, const Tensor &cu_seqlens,
const std::string &first_half, const std::string &second_half,
cudaStream_t stream) {
using namespace transformer_engine;
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
grad.dtype(), dtype,
thd_grad_dispatcher<dtype>(grad, grad_per_step, cu_seqlens, first_half, second_half,
stream););
}
/***************************************************************************************************
* Support THD format for Context Parallel: Generate partitioned indices for input tokens
**************************************************************************************************/
void thd_get_partitioned_indices(const Tensor &cu_seqlens, Tensor output, int total_tokens,
int world_size, int rank, cudaStream_t stream) {
using namespace transformer_engine;
NVTE_CHECK(cu_seqlens.dtype() == DType::kInt32);
NVTE_CHECK(cu_seqlens.dim() == 1);
auto cu_seqlens_shape = cu_seqlens.shape();
auto output_shape = output.shape();
NVTE_CHECK(cu_seqlens_shape[0] >= 2);
NVTE_CHECK(rank >= 0 && rank < world_size);
NVTE_CHECK(world_size > 0);
NVTE_CHECK(total_tokens > 0 && total_tokens % (world_size * 2) == 0);
int batch = cu_seqlens_shape[0] - 1;
constexpr unsigned int block = 256;
unsigned int grid = (output_shape[0] + block - 1) / block;
thd_partition_indices_kernel<<<grid, block, sizeof(int) * (batch + 1), stream>>>(
reinterpret_cast<int *>(output.data.dptr), reinterpret_cast<int *>(cu_seqlens.data.dptr),
batch, total_tokens, world_size, rank);
}
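The exact index pattern comes from thd_partition_indices_kernel, which is not shown in this hunk. As a hedged host-side sketch only — assuming the common context-parallel load-balancing split in which every sequence is cut into 2 * world_size equal chunks and rank r keeps chunks r and 2*world_size-1-r (an assumption for illustration, not something stated in this diff) — the equivalent computation would look like:

    // Hedged sketch of one possible partitioning scheme; not this commit's kernel.
    #include <vector>

    std::vector<int> partitioned_indices_sketch(const std::vector<int> &cu_seqlens,
                                                int world_size, int rank) {
      std::vector<int> out;
      const int batch = static_cast<int>(cu_seqlens.size()) - 1;
      for (int b = 0; b < batch; ++b) {
        const int begin = cu_seqlens[b];
        const int len = cu_seqlens[b + 1] - begin;
        const int chunk = len / (2 * world_size);           // 2 chunks per rank
        const int chunks[2] = {rank, 2 * world_size - 1 - rank};
        for (int c : chunks)
          for (int i = 0; i < chunk; ++i) out.push_back(begin + c * chunk + i);
      }
      return out;
    }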
} // namespace context_parallel
} // namespace transformer_engine
#endif
void nvte_cp_thd_read_half_tensor(const NVTETensor &tensor, const NVTETensor &cu_seqlens,
NVTETensor half, int half_idx, cudaStream_t stream) {
NVTE_API_CALL(nvte_thd_read_half_tensor);
using namespace transformer_engine;
context_parallel::thd_read_half_tensor(*reinterpret_cast<Tensor *>(tensor),
*reinterpret_cast<Tensor *>(cu_seqlens),
*reinterpret_cast<Tensor *>(half), half_idx, stream);
}
void nvte_cp_thd_second_half_lse_correction(NVTETensor lse, const NVTETensor &lse_per_step,
const NVTETensor &cu_seqlens, int lse_packed,
cudaStream_t stream) {
NVTE_API_CALL(nvte_thd_second_half_lse_correction);
using namespace transformer_engine;
context_parallel::thd_second_half_lse_correction(
*reinterpret_cast<Tensor *>(lse), *reinterpret_cast<Tensor *>(lse_per_step),
*reinterpret_cast<Tensor *>(cu_seqlens), lse_packed, stream);
}
void nvte_cp_thd_read_second_half_lse(const NVTETensor &lse, const NVTETensor &cu_seqlens,
NVTETensor half_lse, int lse_packed,
int second_half_lse_seqlen, cudaStream_t stream) {
NVTE_API_CALL(nvte_thd_read_second_half_lse);
using namespace transformer_engine;
context_parallel::thd_read_second_half_lse(
*reinterpret_cast<Tensor *>(lse), *reinterpret_cast<Tensor *>(cu_seqlens),
*reinterpret_cast<Tensor *>(half_lse), lse_packed, second_half_lse_seqlen, stream);
}
void nvte_cp_thd_out_correction(NVTETensor out, const NVTETensor &out_per_step,
const NVTETensor &lse, const NVTETensor &lse_per_step,
const NVTETensor &cu_seqlens, int only_second_half, int lse_packed,
cudaStream_t stream) {
NVTE_API_CALL(nvte_thd_out_correction);
using namespace transformer_engine;
context_parallel::thd_out_correction(
*reinterpret_cast<Tensor *>(out), *reinterpret_cast<Tensor *>(out_per_step),
*reinterpret_cast<Tensor *>(lse), *reinterpret_cast<Tensor *>(lse_per_step),
*reinterpret_cast<Tensor *>(cu_seqlens), only_second_half, lse_packed, stream);
}
void nvte_cp_thd_grad_correction(NVTETensor grad, const NVTETensor &grad_per_step,
const NVTETensor &cu_seqlens, const char *first_half,
const char *second_half, cudaStream_t stream) {
NVTE_API_CALL(nvte_thd_grad_correction);
using namespace transformer_engine;
std::string first_half_str(first_half);
std::string second_half_str(second_half);
context_parallel::thd_grad_correction(
*reinterpret_cast<Tensor *>(grad), *reinterpret_cast<Tensor *>(grad_per_step),
*reinterpret_cast<Tensor *>(cu_seqlens), first_half_str, second_half_str, stream);
}
void nvte_cp_thd_get_partitioned_indices(const NVTETensor &cu_seqlens, NVTETensor output,
int total_tokens, int world_size, int rank,
cudaStream_t stream) {
NVTE_API_CALL(nvte_thd_get_partitioned_indices);
using namespace transformer_engine;
context_parallel::thd_get_partitioned_indices(*reinterpret_cast<Tensor *>(cu_seqlens),
*reinterpret_cast<Tensor *>(output), total_tokens,
world_size, rank, stream);
}
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "../common.h"
#include "transformer_engine/fused_attn.h"
namespace transformer_engine {
namespace flash_attention {
constexpr int warp_size = 32;
constexpr int type_size = 2; // FP16 or BF16
constexpr int nvec = sizeof(uint64_t) / type_size;
constexpr int load_size = warp_size * nvec;
constexpr int block_size = 512;
template <typename T>
__launch_bounds__(block_size) __global__
void prepare_kernel_fwd(const T *qkvi, T *qkv, const size_t B, const size_t S, const size_t Z,
const size_t W) {
const int warpid = (blockDim.x * blockIdx.x + threadIdx.x) / warp_size;
const int id_in_warp = threadIdx.x % warp_size;
const size_t offset_input = blockIdx.y * W + warpid * 3 * W * Z + id_in_warp * nvec;
const T *my_input = qkvi + offset_input;
const size_t s = warpid / B;
if (s >= S) return;
const size_t b = warpid % B;
const size_t offset_output = blockIdx.y * B * S * Z * W + (s + b * S) * W * Z + id_in_warp * nvec;
T *my_output = qkv + offset_output;
for (int i = 0; i < Z; ++i) {
uint64_t *out = reinterpret_cast<uint64_t *>(my_output + i * load_size);
*out = *reinterpret_cast<const uint64_t *>(my_input + i * load_size * 3);
}
}
template <typename T>
__launch_bounds__(block_size) __global__
void prepare_kernel_bwd(const T *q, const T *k, const T *v, T *qkv, const size_t B,
const size_t S, const size_t Z, const size_t W) {
const T *input = blockIdx.y == 0 ? q : (blockIdx.y == 1 ? k : v);
const int warpid = (blockDim.x * blockIdx.x + threadIdx.x) / warp_size;
const int id_in_warp = threadIdx.x % warp_size;
const size_t offset_input = warpid * W * Z + id_in_warp * nvec;
const T *my_input = input + offset_input;
const size_t b = warpid / S;
if (b >= B) return;
const size_t s = warpid % S;
const size_t offset_output = (b + s * B) * 3 * W * Z + id_in_warp * nvec + blockIdx.y * W;
T *my_output = qkv + offset_output;
for (int i = 0; i < Z; ++i) {
uint64_t *out = reinterpret_cast<uint64_t *>(my_output + i * load_size * 3);
*out = *reinterpret_cast<const uint64_t *>(my_input + i * load_size);
}
}
void prepare_flash_attn_fwd(Tensor qkvi, Tensor qkv, cudaStream_t stream) {
using namespace transformer_engine;
NVTE_CHECK(qkvi.dim() == 4, "Expected 4-dim tensor.");
NVTE_CHECK(qkvi.dtype() == DType::kFloat16 || qkvi.dtype() == DType::kBFloat16);
auto qkvi_shape = qkvi.shape();
NVTE_CHECK(qkvi_shape[3] % load_size == 0);
NVTE_CHECK(qkvi_shape[3] == load_size);
// [s, b, n, h * 3] -> [3, b, s, n, h]
std::vector<uint64_t> shape = {3, qkvi_shape[1], qkvi_shape[0], qkvi_shape[2], qkvi_shape[3]};
size_t warps = qkvi_shape[0] * qkvi_shape[1];
size_t warps_per_block = block_size / warp_size;
size_t blocks = (warps + warps_per_block - 1) / warps_per_block;
dim3 grid(blocks, 3);
int threads = block_size;
TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(
qkvi.dtype(), dtype,
prepare_kernel_fwd<dtype><<<grid, threads, 0, stream>>>(
reinterpret_cast<dtype *>(qkvi.data.dptr), reinterpret_cast<dtype *>(qkv.data.dptr),
shape[1], shape[2], shape[3], shape[4]););
}
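As a readability aid, here is a hedged host-side reference loop for the same [s, b, n, h * 3] -> [3, b, s, n, h] permutation that prepare_kernel_fwd performs with vectorized 64-bit loads. It assumes each head's h*3 slice stores its q, k and v rows contiguously in that order (as the kernel's per-head stride of 3*W suggests); it is illustrative only, not code from this commit:

    // Hedged scalar reference for the qkvi -> qkv layout transform.
    #include <cstddef>

    template <typename T>
    void qkvi_to_qkv_reference(const T *qkvi, T *qkv, std::size_t S, std::size_t B,
                               std::size_t N, std::size_t H) {
      for (std::size_t s = 0; s < S; ++s)
        for (std::size_t b = 0; b < B; ++b)
          for (std::size_t n = 0; n < N; ++n)
            for (std::size_t m = 0; m < 3; ++m)      // 0: q, 1: k, 2: v
              for (std::size_t h = 0; h < H; ++h) {
                const std::size_t src = ((s * B + b) * N + n) * 3 * H + m * H + h;
                const std::size_t dst = (((m * B + b) * S + s) * N + n) * H + h;
                qkv[dst] = qkvi[src];
              }
    }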
void prepare_flash_attn_bwd(Tensor q, Tensor k, Tensor v, Tensor qkv, cudaStream_t stream) {
using namespace transformer_engine;
NVTE_CHECK(q.dim() == 4, "Expected 4-dim tensor.");
NVTE_CHECK(k.dim() == 4, "Expected 4-dim tensor.");
NVTE_CHECK(v.dim() == 4, "Expected 4-dim tensor.");
NVTE_CHECK(q.dtype() == DType::kFloat16 || q.dtype() == DType::kBFloat16);
NVTE_CHECK(k.dtype() == q.dtype());
NVTE_CHECK(v.dtype() == q.dtype());
auto q_shape = q.shape();
auto k_shape = k.shape();
auto v_shape = v.shape();
NVTE_CHECK(q_shape[3] % load_size == 0);
NVTE_CHECK(q_shape[3] == load_size);
NVTE_CHECK(k_shape[3] % load_size == 0);
NVTE_CHECK(k_shape[3] == load_size);
NVTE_CHECK(v_shape[3] % load_size == 0);
NVTE_CHECK(v_shape[3] == load_size);
// 3 x [s, b, n, h] -> [b, s, n, 3 * h]
std::vector<uint64_t> shape = {q_shape[1], q_shape[0], q_shape[2], 3 * q_shape[3]};
size_t warps = q_shape[0] * q_shape[1];
size_t warps_per_block = block_size / warp_size;
size_t blocks = (warps + warps_per_block - 1) / warps_per_block;
dim3 grid(blocks, 3);
int threads = block_size;
TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(
q.dtype(), dtype,
prepare_kernel_bwd<dtype><<<grid, threads, 0, stream>>>(
reinterpret_cast<dtype *>(q.data.dptr), reinterpret_cast<dtype *>(k.data.dptr),
reinterpret_cast<dtype *>(v.data.dptr), reinterpret_cast<dtype *>(qkv.data.dptr),
q_shape[0], q_shape[1], q_shape[2], q_shape[3]););
}
} // namespace flash_attention
} // namespace transformer_engine
void nvte_prepare_flash_attn_fwd(NVTETensor qkvi, NVTETensor qkv, cudaStream_t stream) {
NVTE_API_CALL(nvte_prepare_flash_attn_fwd);
using namespace transformer_engine;
flash_attention::prepare_flash_attn_fwd(*reinterpret_cast<Tensor *>(qkvi),
*reinterpret_cast<Tensor *>(qkv), stream);
}
void nvte_prepare_flash_attn_bwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTETensor qkv,
cudaStream_t stream) {
NVTE_API_CALL(nvte_prepare_flash_attn_bwd);
using namespace transformer_engine;
flash_attention::prepare_flash_attn_bwd(
*reinterpret_cast<Tensor *>(q), *reinterpret_cast<Tensor *>(k),
*reinterpret_cast<Tensor *>(v), *reinterpret_cast<Tensor *>(qkv), stream);
}
......@@ -3,48 +3,15 @@
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_KV_CACHE_CUH_
#define TRANSFORMER_ENGINE_FUSED_ATTN_KV_CACHE_CUH_
namespace transformer_engine {
namespace fused_attn {
template <typename scalar_t>
__global__ void convert_thd_to_bshd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens,
int b, int max_seq_len, int h, int d) {
// tensor: thd; new_tensor: bshd
// cu_seqlens: [b + 1]
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
int num_elts = (cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]) * h * d;
int thd_offset = cu_seqlens[batch_idx] * h * d;
int bshd_offset = batch_idx * max_seq_len * h * d;
scalar_t *thd_token = tensor + thd_offset;
scalar_t *bshd_token = new_tensor + bshd_offset;
for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
*(bshd_token + i) = *(thd_token + i);
}
}
}
#include "../common.h"
#include "transformer_engine/fused_attn.h"
template <typename scalar_t>
__global__ void convert_bshd_to_thd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens,
int b, int max_seq_len, int h, int d) {
// tensor: bshd; new_tensor: thd
// cu_seqlens: [b + 1]
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
int seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx];
int num_elts = seqlen * h * d;
int bshd_offset = batch_idx * max_seq_len * h * d;
int thd_offset = cu_seqlens[batch_idx] * h * d;
scalar_t *bshd_token = tensor + bshd_offset;
scalar_t *thd_token = new_tensor + thd_offset;
for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
*(thd_token + i) = *(bshd_token + i);
}
}
}
namespace transformer_engine {
namespace kv_cache {
template <typename scalar_t>
__global__ void reindex_kv_cache_kernel(scalar_t *k_cache, scalar_t *v_cache, int *batch_indices,
template <typename dtype>
__global__ void reindex_kv_cache_kernel(dtype *k_cache, dtype *v_cache, int *batch_indices,
int *cu_new_lens, int *cu_cached_lens, int h_kv, int d_k,
int d_v, int b, int max_seq_len) {
// k_cache, v_cache: bshd
......@@ -75,11 +42,11 @@ __global__ void reindex_kv_cache_kernel(scalar_t *k_cache, scalar_t *v_cache, in
}
}
template <typename scalar_t>
__global__ void copy_to_kv_cache_kernel(scalar_t *new_k, scalar_t *new_v, scalar_t *k_cache,
scalar_t *v_cache, int *page_table, int *cu_new_lens,
int *cu_cached_lens, NVTE_QKV_Format qkv_format, int h_kv,
int d_k, int d_v, int b, int max_ctx_len, int max_seq_len,
template <typename dtype>
__global__ void copy_to_kv_cache_kernel(dtype *new_k, dtype *new_v, dtype *k_cache, dtype *v_cache,
int *page_table, int *cu_new_lens, int *cu_cached_lens,
NVTE_QKV_Format qkv_format, int h_kv, int d_k, int d_v,
int b, int max_ctx_len, int max_seq_len,
int max_pages_per_seq, bool is_non_paged) {
// new_k, new_v: qkv_format; k_cache, v_cache: bshd
// cu_new_lens, cu_cached_lens: [b + 1]
......@@ -140,6 +107,191 @@ __global__ void copy_to_kv_cache_kernel(scalar_t *new_k, scalar_t *new_v, scalar
}
}
}
} // namespace fused_attn
template <typename dtype>
void copy_to_kv_cache_launcher(Tensor new_k, Tensor new_v, Tensor k_cache, Tensor v_cache,
Tensor page_table, Tensor cu_new_lens, Tensor cu_cached_lens,
NVTE_QKV_Format qkv_format, int h_kv, int d_k, int d_v, int b,
int max_ctx_len, int max_seq_len, int max_pages_per_seq,
bool is_non_paged, cudaStream_t stream) {
if (new_k.has_data() && new_v.has_data() && k_cache.has_data() && v_cache.has_data()) {
if (is_non_paged) {
reindex_kv_cache_kernel<<<16, 256, 0, stream>>>(
reinterpret_cast<dtype *>(k_cache.data.dptr),
reinterpret_cast<dtype *>(v_cache.data.dptr),
reinterpret_cast<int *>(page_table.data.dptr),
reinterpret_cast<int *>(cu_new_lens.data.dptr),
reinterpret_cast<int *>(cu_cached_lens.data.dptr), h_kv, d_k, d_v, b, max_seq_len);
}
copy_to_kv_cache_kernel<<<16, 256, 0, stream>>>(
reinterpret_cast<dtype *>(new_k.data.dptr), reinterpret_cast<dtype *>(new_v.data.dptr),
reinterpret_cast<dtype *>(k_cache.data.dptr), reinterpret_cast<dtype *>(v_cache.data.dptr),
reinterpret_cast<int *>(page_table.data.dptr),
reinterpret_cast<int *>(cu_new_lens.data.dptr),
reinterpret_cast<int *>(cu_cached_lens.data.dptr), qkv_format, h_kv, d_k, d_v, b,
max_ctx_len, max_seq_len, max_pages_per_seq, is_non_paged);
}
}
void copy_to_kv_cache(Tensor new_k, Tensor new_v, Tensor k_cache, Tensor v_cache, Tensor page_table,
Tensor cu_new_lens, Tensor cu_cached_lens, NVTE_QKV_Format qkv_format, int b,
int max_ctx_len, int max_seq_len, int max_pages_per_seq, bool is_non_paged,
cudaStream_t stream) {
int h_kv = new_k.shape()[new_k.dim() - 2];
int d_k = new_k.shape()[new_k.dim() - 1];
int d_v = new_v.shape()[new_v.dim() - 1];
NVTE_CHECK(k_cache.dtype() == v_cache.dtype() && new_k.dtype() == new_v.dtype() &&
new_k.dtype() == k_cache.dtype(),
"new_k, new_v, k_cache and v_cache must be of the same data type.");
NVTE_CHECK(qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD ||
qkv_format == NVTE_QKV_Format::NVTE_THD,
"qkv_format must be {BSHD, SBHD, THD}.");
TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(
k_cache.dtype(), dtype,
copy_to_kv_cache_launcher<dtype>(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens,
cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len,
max_seq_len, max_pages_per_seq, is_non_paged, stream););
}
template <typename scalar_t>
__global__ void convert_thd_to_bshd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens,
int b, int max_seq_len, int h, int d) {
// tensor: thd; new_tensor: bshd
// cu_seqlens: [b + 1]
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
int num_elts = (cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]) * h * d;
int thd_offset = cu_seqlens[batch_idx] * h * d;
int bshd_offset = batch_idx * max_seq_len * h * d;
scalar_t *thd_token = tensor + thd_offset;
scalar_t *bshd_token = new_tensor + bshd_offset;
for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
*(bshd_token + i) = *(thd_token + i);
}
}
}
template <typename scalar_t>
void convert_thd_to_bshd_launcher(Tensor tensor, Tensor new_tensor, Tensor cu_seqlens, int b,
int max_seq_len, int h, int d, cudaStream_t stream) {
using namespace transformer_engine;
convert_thd_to_bshd_kernel<<<16, 256, 0, stream>>>(
reinterpret_cast<scalar_t *>(tensor.data.dptr),
reinterpret_cast<scalar_t *>(new_tensor.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), b, max_seq_len, h, d);
}
void convert_thd_to_bshd(Tensor tensor, Tensor cu_seqlens, Tensor new_tensor, int b,
int max_seq_len, cudaStream_t stream) {
using namespace transformer_engine;
auto tensor_shape = tensor.shape();
TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(
new_tensor.dtype(), dtype,
convert_thd_to_bshd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len,
tensor_shape[1], tensor_shape[2], stream););
}
template <typename scalar_t>
__global__ void convert_bshd_to_thd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens,
int b, int max_seq_len, int h, int d) {
// tensor: bshd; new_tensor: thd
// cu_seqlens: [b + 1]
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
int seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx];
int num_elts = seqlen * h * d;
int bshd_offset = batch_idx * max_seq_len * h * d;
int thd_offset = cu_seqlens[batch_idx] * h * d;
scalar_t *bshd_token = tensor + bshd_offset;
scalar_t *thd_token = new_tensor + thd_offset;
for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
*(thd_token + i) = *(bshd_token + i);
}
}
}
template <typename scalar_t>
void convert_bshd_to_thd_launcher(Tensor tensor, Tensor new_tensor, Tensor cu_seqlens, int b,
int max_seq_len, int h, int d, cudaStream_t stream) {
using namespace transformer_engine;
convert_bshd_to_thd_kernel<<<16, 256, 0, stream>>>(
reinterpret_cast<scalar_t *>(tensor.data.dptr),
reinterpret_cast<scalar_t *>(new_tensor.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), b, max_seq_len, h, d);
}
void convert_bshd_to_thd(Tensor tensor, Tensor cu_seqlens, Tensor new_tensor, int t,
cudaStream_t stream) {
using namespace transformer_engine;
auto tensor_shape = tensor.shape();
TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(
tensor.dtype(), dtype,
convert_bshd_to_thd_launcher<dtype>(tensor, new_tensor, cu_seqlens, tensor_shape[0],
tensor_shape[1], tensor_shape[2], tensor_shape[3],
stream););
}
} // namespace kv_cache
} // namespace transformer_engine
#endif
/***************************************************************************************************
* KV Cache: Copy new KV tokens to the KV cache
* 1. new_k and new_v are in qkv_format; k_cache and v_cache are in 'bshd' format
* 2. cu_new_lens and cu_cached_lens are in shape [b + 1]; cu_cached_lens include the added lens
* in current step
* 3. Non-paged KV cache is a special case of paged KV cache, with page_table = [b, 1] and
* max_pages_per_seq = 1. We use the same underlying kernel for both non-paged and paged.
* Set is_non_paged = True/False to indicate as such.
* 4. is_non_paged = True also re-indexes the KV cache, e.g. the initial batch indices [0, 3, 1, 2]
* becomes [0, 1, 1, 2]. The page_table = batch_indices.unsqueeze(1) is however unchanged.
* batch_indices_post can be used for monotonical indexing, i.e. [0, 1, 2, 3]. batch_indices is
* preserved for the next layer in the same iteration.
* 5. Only supports same page_table for k_cache and v_cache
* 6. Only pad_between_seqs = False when qkv_format = thd, i.e. there should be no pad tokens
* between sequences in new_k and new_v such as [a a a 0..0 b b 0..0 c 0..0].
**************************************************************************************************/
void nvte_copy_to_kv_cache(NVTETensor new_k, NVTETensor new_v, NVTETensor k_cache,
NVTETensor v_cache, NVTETensor page_table, NVTETensor cu_new_lens,
NVTETensor cu_cached_lens, NVTE_QKV_Format qkv_format, int b,
int max_ctx_len, int max_seq_len, int max_pages_per_seq,
int is_non_paged, cudaStream_t stream) {
NVTE_API_CALL(nvte_copy_to_kv_cache);
using namespace transformer_engine;
kv_cache::copy_to_kv_cache(
*reinterpret_cast<Tensor *>(new_k), *reinterpret_cast<Tensor *>(new_v),
*reinterpret_cast<Tensor *>(k_cache), *reinterpret_cast<Tensor *>(v_cache),
*reinterpret_cast<Tensor *>(page_table), *reinterpret_cast<Tensor *>(cu_new_lens),
*reinterpret_cast<Tensor *>(cu_cached_lens), qkv_format, b, max_ctx_len, max_seq_len,
max_pages_per_seq, is_non_paged, stream);
}
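To make notes 3 and 4 of the comment above concrete, here is a hedged sketch of paged-cache addressing under assumed conventions (a row-major [b, max_pages_per_seq] page_table and page_size = max_seq_len / max_pages_per_seq); this is an illustration, not the kernel from this commit:

    // Hedged illustration: locate the bshd-cache row holding token `pos` of batch
    // entry `b_idx`. In the non-paged case, max_pages_per_seq == 1 and the single
    // "page" is simply the (re-indexed) batch slot.
    #include <cstddef>

    inline std::size_t cache_row_offset(const int *page_table, int b_idx, int pos,
                                        int max_pages_per_seq, int page_size,
                                        int h_kv, int d) {
      const int page = page_table[b_idx * max_pages_per_seq + pos / page_size];
      const int slot = pos % page_size;
      return (static_cast<std::size_t>(page) * page_size + slot) *
             static_cast<std::size_t>(h_kv) * d;
    }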
/***************************************************************************************************
* KV Cache: Convert a tensor from qkv_format = thd to qkv_format = bshd
**************************************************************************************************/
void nvte_convert_thd_to_bshd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor,
int b, int max_seq_len, cudaStream_t stream) {
NVTE_API_CALL(nvte_convert_thd_to_bshd);
using namespace transformer_engine;
kv_cache::convert_thd_to_bshd(*reinterpret_cast<Tensor *>(tensor),
*reinterpret_cast<Tensor *>(cu_seqlens),
*reinterpret_cast<Tensor *>(new_tensor), b, max_seq_len, stream);
}
/***************************************************************************************************
* KV Cache: Convert a tensor from qkv_format = bshd to qkv_format = thd
**************************************************************************************************/
void nvte_convert_bshd_to_thd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor,
int t, cudaStream_t stream) {
NVTE_API_CALL(nvte_convert_bshd_to_thd);
using namespace transformer_engine;
kv_cache::convert_bshd_to_thd(*reinterpret_cast<Tensor *>(tensor),
*reinterpret_cast<Tensor *>(cu_seqlens),
*reinterpret_cast<Tensor *>(new_tensor), t, stream);
}
......@@ -610,5 +610,27 @@ uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cud
return hout;
}
__global__ void extract_seed_and_offset(int64_t *rng_state_ptr, bool captured, int64_t *seed_ptr,
uint64_t seed_val, int64_t *offset_ptr, uint64_t offset_val,
uint32_t offset_intragraph) {
if (captured) {
rng_state_ptr[0] = *seed_ptr;
rng_state_ptr[1] = static_cast<int64_t>(*offset_ptr + static_cast<int64_t>(offset_intragraph));
} else {
rng_state_ptr[0] = static_cast<int64_t>(seed_val);
rng_state_ptr[1] = static_cast<int64_t>(offset_val);
}
}
} // namespace fused_attn
} // namespace transformer_engine
void nvte_extract_seed_and_offset(int64_t *rng_state_ptr, int captured, int64_t *seed_ptr,
uint64_t seed_val, int64_t *offset_ptr, uint64_t offset_val,
uint32_t offset_intragraph, cudaStream_t stream) {
NVTE_API_CALL(nvte_extract_seed_and_offset);
using namespace transformer_engine;
fused_attn::extract_seed_and_offset<<<1, 1, 0, stream>>>(
rng_state_ptr, captured, seed_ptr, seed_val, offset_ptr, offset_val, offset_intragraph);
}
......@@ -244,7 +244,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S,
NVTETensor O, NVTETensorPack* Aux_CTX_Tensors,
NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
const NVTETensor rng_state, size_t max_seqlen, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
......@@ -300,7 +300,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
*/
void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO,
const NVTETensor S, NVTETensor dP,
const NVTETensorPack* Aux_CTX_Tensors, NVTETensor dQKV,
const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV,
NVTETensor dBias, const NVTETensor cu_seqlens,
const NVTETensor cu_seqlens_padded, size_t max_seqlen,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
......@@ -368,7 +368,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
*/
void nvte_fused_attn_fwd_kvpacked(
const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O,
NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
......@@ -429,7 +429,7 @@ void nvte_fused_attn_fwd_kvpacked(
*/
void nvte_fused_attn_bwd_kvpacked(
const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO,
const NVTETensor S, NVTETensor dP, const NVTETensorPack* Aux_CTX_Tensors, NVTETensor dQ,
const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ,
NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout,
......@@ -500,7 +500,7 @@ void nvte_fused_attn_bwd_kvpacked(
*/
void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
const NVTETensor Bias, NVTETensor S, NVTETensor O,
NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
const NVTETensor page_table_v, const NVTETensor rng_state,
......@@ -569,7 +569,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
*/
void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP,
const NVTETensorPack* Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK,
const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK,
NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q,
......@@ -604,6 +604,51 @@ void nvte_populate_rng_state_async(NVTETensor rng_state_dst, const NVTETensor se
uint32_t nvte_get_runtime_num_segments(NVTETensor cu_seqlen, NVTETensor workspace, size_t len,
cudaStream_t stream);
void nvte_extract_seed_and_offset(int64_t *rng_state_ptr, int captured, int64_t *seed_ptr,
uint64_t seed_val, int64_t *offset_ptr, uint64_t offset_val,
uint32_t offset_intragraph, cudaStream_t stream);
void nvte_copy_to_kv_cache(NVTETensor new_k, NVTETensor new_v, NVTETensor k_cache,
NVTETensor v_cache, NVTETensor page_table, NVTETensor cu_new_lens,
NVTETensor cu_cached_lens, NVTE_QKV_Format qkv_format, int b,
int max_ctx_len, int max_seq_len, int max_pages_per_seq,
int is_non_paged, cudaStream_t stream);
void nvte_cp_thd_read_half_tensor(const NVTETensor &tensor, const NVTETensor &cu_seqlens,
NVTETensor half, int half_idx, cudaStream_t stream);
void nvte_cp_thd_second_half_lse_correction(NVTETensor lse, const NVTETensor &lse_per_step,
const NVTETensor &cu_seqlens, int lse_packed,
cudaStream_t stream);
void nvte_cp_thd_read_second_half_lse(const NVTETensor &lse, const NVTETensor &cu_seqlens,
NVTETensor half_lse, int lse_packed,
int second_half_lse_seqlen, cudaStream_t stream);
void nvte_cp_thd_out_correction(NVTETensor out, const NVTETensor &out_per_step,
const NVTETensor &lse, const NVTETensor &lse_per_step,
const NVTETensor &cu_seqlens, int only_second_half, int lse_packed,
cudaStream_t stream);
void nvte_cp_thd_grad_correction(NVTETensor grad, const NVTETensor &grad_per_step,
const NVTETensor &cu_seqlens, const char *first_half,
const char *second_half, cudaStream_t stream);
void nvte_cp_thd_get_partitioned_indices(const NVTETensor &cu_seqlens, NVTETensor output,
int total_tokens, int world_size, int rank,
cudaStream_t stream);
void nvte_convert_thd_to_bshd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor,
int b, int max_seq_len, cudaStream_t stream);
void nvte_convert_bshd_to_thd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor,
int t, cudaStream_t stream);
void nvte_prepare_flash_attn_fwd(NVTETensor qkvi, NVTETensor qkv, cudaStream_t stream);
void nvte_prepare_flash_attn_bwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTETensor qkv,
cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
......
......@@ -218,6 +218,26 @@ std::vector<size_t> getTensorShape(at::Tensor t);
transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
const std::string& fp8_recipe);
inline size_t typeToSize(transformer_engine::DType t) {
switch (t) {
case transformer_engine::DType::kInt64:
return 8;
case transformer_engine::DType::kInt32:
case transformer_engine::DType::kFloat32:
return 4;
case transformer_engine::DType::kInt16:
case transformer_engine::DType::kFloat16:
case transformer_engine::DType::kBFloat16:
return 2;
case transformer_engine::DType::kByte:
case transformer_engine::DType::kFloat8E4M3:
case transformer_engine::DType::kFloat8E5M2:
return 1;
default:
NVTE_ERROR("Invalid type");
}
}
inline at::ScalarType GetATenDType(transformer_engine::DType t) {
switch (t) {
case transformer_engine::DType::kInt16:
......
......@@ -72,10 +72,10 @@ at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v);
at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len);
at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t);
void copy_to_kv_cache(torch::Tensor new_k, torch::Tensor new_v, torch::Tensor k_cache,
torch::Tensor v_cache, torch::Tensor page_table, torch::Tensor cu_new_lens,
torch::Tensor cu_cached_lens, NVTE_QKV_Format kv_format, int b,
int max_ctx_len, int max_seq_len, int max_pages_per_seq, bool is_non_paged);
void copy_to_kv_cache(at::Tensor new_k, at::Tensor new_v, at::Tensor k_cache, at::Tensor v_cache,
at::Tensor page_table, at::Tensor cu_new_lens, at::Tensor cu_cached_lens,
NVTE_QKV_Format kv_format, int b, int max_ctx_len, int max_seq_len,
int max_pages_per_seq, bool is_non_paged);
/***************************************************************************************************
* GEMM
......@@ -392,12 +392,11 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
namespace nvshmem_api {
void init_nvshmem_backend(c10d::ProcessGroup *process_group);
torch::Tensor create_nvshmem_tensor(const std::vector<int64_t> &shape, c10::ScalarType dtype);
at::Tensor create_nvshmem_tensor(const std::vector<int64_t> &shape, c10::ScalarType dtype);
void nvshmem_send_on_current_stream(torch::Tensor src, torch::Tensor dst, int peer,
torch::Tensor signal);
void nvshmem_send_on_current_stream(at::Tensor src, at::Tensor dst, int peer, at::Tensor signal);
void nvshmem_wait_on_current_stream(torch::Tensor signal, const std::string &wait_kind);
void nvshmem_wait_on_current_stream(at::Tensor signal, const std::string &wait_kind);
void nvshmem_finalize();
} // namespace nvshmem_api
......
......@@ -5,12 +5,8 @@
************************************************************************/
#include "extensions.h"
#include "kv_cache.cuh"
#include "thd_utils.cuh"
#include "transformer_engine/transformer_engine.h"
constexpr int block_size = 512;
constexpr int ctas_per_sm = 4;
// get the fused attention backend
NVTE_Fused_Attn_Backend get_fused_attn_backend(
......@@ -26,19 +22,6 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(
return fused_attention_backend;
}
// fast zero-fills of tensors
template <typename scalar_t>
__global__ void __launch_bounds__(block_size)
mha_fill_kernel(scalar_t *out_tensor, const int32_t *const start_row, const size_t num_rows) {
size_t row_stride = gridDim.y * blockDim.x;
size_t row_index = blockIdx.x + static_cast<size_t>(start_row[0]);
size_t col_index = blockIdx.y * blockDim.x + threadIdx.x;
while (row_index < num_rows) {
out_tensor[row_index * row_stride + col_index] = 0;
row_index += gridDim.x;
}
}
// fast zero-fills of tensors
void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &start_index) {
std::vector<size_t> shape = transformer_engine::pytorch::convertShape(self.shape());
......@@ -48,33 +31,23 @@ void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &s
for (int i = 1; i <= shape.size(); i++) {
fcd_size *= shape[i];
}
TORCH_CHECK(fcd_size % block_size == 0, "input size not aligned to block size");
const int num_mp = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
uint64_t num_blk_y = (uint64_t)(fcd_size / block_size);
uint64_t num_blk_x = (uint64_t)((num_mp * ctas_per_sm + num_blk_y - 1) / num_blk_y);
dim3 dim_grid(num_blk_x, num_blk_y);
dim3 dim_block(block_size);
// need to somehow convert DType to scalar_type
at::ScalarType scalar_type = transformer_engine::pytorch::GetATenDType(self.dtype());
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
at::ScalarType::Half, at::ScalarType::BFloat16, scalar_type, "mha_fill", [&]() {
mha_fill_kernel<<<dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
static_cast<scalar_t *>(self.get_rowwise_data().data_ptr),
static_cast<int32_t *>(start_index.data_ptr()), max_tokens);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
NVTE_CHECK(fcd_size % block_size == 0, "input size not aligned to block size");
size_t element_size = transformer_engine::pytorch::typeToSize(self.dtype());
int32_t start_row = start_index.data_ptr<int32_t>()[0];
void *base_ptr = static_cast<char *>(self.get_rowwise_data().data_ptr) +
static_cast<size_t>(start_row) * fcd_size * element_size;
size_t num_rows_to_zero = max_tokens - start_row;
size_t total_bytes = num_rows_to_zero * fcd_size * element_size;
nvte_memset(base_ptr, 0, total_bytes, at::cuda::getCurrentCUDAStream());
}
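To make the zero-fill arithmetic concrete, a worked example under assumed sizes (shape = [4096, 16, 64] in bf16, start_row = 1000); the numbers are illustrative only.

// fcd_size     = 16 * 64                   = 1024 elements per row
// element_size = 2 bytes (bf16)
// base_ptr     = data + 1000 * 1024 * 2    = data + 2,048,000 bytes
// total_bytes  = (4096 - 1000) * 1024 * 2  = 6,340,608 bytes cleared by a single nvte_memset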
// extract seed and offset from PhiloxCudaState
__global__ void unpack(at::PhiloxCudaState arg, int64_t *rng_state_ptr) {
if (arg.captured_) {
rng_state_ptr[0] = static_cast<int64_t>(*arg.seed_.ptr);
rng_state_ptr[1] =
static_cast<int64_t>(*(arg.offset_.ptr) + static_cast<int64_t>(arg.offset_intragraph_));
} else {
rng_state_ptr[0] = static_cast<int64_t>(arg.seed_.val);
rng_state_ptr[1] = static_cast<int64_t>(arg.offset_.val);
}
void unpack(at::PhiloxCudaState arg, int64_t *rng_state_ptr) {
nvte_extract_seed_and_offset(rng_state_ptr, arg.captured_, arg.seed_.ptr, arg.seed_.val,
arg.offset_.ptr, arg.offset_.val, arg.offset_intragraph_,
at::cuda::getCurrentCUDAStream());
}
// extract PhiloxCudaState from CUDA random number generator
@@ -193,8 +166,7 @@ std::vector<py::object> fused_attn_fwd(
rng_gen, at::cuda::detail::getDefaultCUDAGenerator());
at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread);
auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
unpack<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
philox_args, static_cast<int64_t *>(rng_state.data_ptr()));
unpack(philox_args, static_cast<int64_t *>(rng_state.data_ptr()));
auto te_rng_state = makeTransformerEngineTensor(rng_state);
// create auxiliary output tensors
@@ -512,72 +484,13 @@ std::vector<py::object> fused_attn_bwd(
return {py_dQ, py_dK, py_dV, py::cast(dBias)};
}
namespace flash_attention {
constexpr int warp_size = 32;
constexpr int type_size = 2; // FP16 or BF16
constexpr int nvec = sizeof(uint64_t) / type_size;
constexpr int load_size = warp_size * nvec;
constexpr int block_size = 512;
template <typename T>
__launch_bounds__(block_size) __global__
void prepare_kernel_fwd(const T *qkvi, T *qkv, const size_t B, const size_t S, const size_t Z,
const size_t W) {
const int warpid = (blockDim.x * blockIdx.x + threadIdx.x) / warp_size;
const int id_in_warp = threadIdx.x % warp_size;
const size_t offset_input = blockIdx.y * W + warpid * 3 * W * Z + id_in_warp * nvec;
const T *my_input = qkvi + offset_input;
const size_t s = warpid / B;
if (s >= S) return;
const size_t b = warpid % B;
const size_t offset_output = blockIdx.y * B * S * Z * W + (s + b * S) * W * Z + id_in_warp * nvec;
T *my_output = qkv + offset_output;
for (int i = 0; i < Z; ++i) {
uint64_t *out = reinterpret_cast<uint64_t *>(my_output + i * load_size);
*out = *reinterpret_cast<const uint64_t *>(my_input + i * load_size * 3);
}
}
template <typename T>
__launch_bounds__(block_size) __global__
void prepare_kernel_bwd(const T *q, const T *k, const T *v, T *qkv, const size_t B,
const size_t S, const size_t Z, const size_t W) {
const T *input = blockIdx.y == 0 ? q : (blockIdx.y == 1 ? k : v);
const int warpid = (blockDim.x * blockIdx.x + threadIdx.x) / warp_size;
const int id_in_warp = threadIdx.x % warp_size;
const size_t offset_input = warpid * W * Z + id_in_warp * nvec;
const T *my_input = input + offset_input;
const size_t b = warpid / S;
if (b >= B) return;
const size_t s = warpid % S;
const size_t offset_output = (b + s * B) * 3 * W * Z + id_in_warp * nvec + blockIdx.y * W;
T *my_output = qkv + offset_output;
for (int i = 0; i < Z; ++i) {
uint64_t *out = reinterpret_cast<uint64_t *>(my_output + i * load_size * 3);
*out = *reinterpret_cast<const uint64_t *>(my_input + i * load_size);
}
}
} // namespace flash_attention
at::Tensor fa_prepare_fwd(at::Tensor qkvi) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
NVTE_CHECK(qkvi.dim() == 4, "Expected 4-dim tensor.");
NVTE_CHECK(qkvi.scalar_type() == at::ScalarType::Half ||
qkvi.scalar_type() == at::ScalarType::BFloat16);
NVTE_CHECK(qkvi.size(3) % flash_attention::load_size == 0);
NVTE_CHECK(qkvi.size(3) == flash_attention::load_size);
NVTE_CHECK(qkvi.stride(3) == 1, "Wrong stride.");
NVTE_CHECK(qkvi.stride(2) == 3 * qkvi.size(3), "Wrong stride.");
NVTE_CHECK(qkvi.stride(1) == 3 * qkvi.size(3) * qkvi.size(2), "Wrong stride.");
@@ -587,27 +500,18 @@ at::Tensor fa_prepare_fwd(at::Tensor qkvi) {
std::vector<int64_t> shape = {3, qkvi.size(1), qkvi.size(0), qkvi.size(2), qkvi.size(3)};
at::Tensor qkv = at::empty(shape, at::CUDA(qkvi.scalar_type()));
size_t warps = qkvi.size(0) * qkvi.size(1);
size_t warps_per_block = flash_attention::block_size / flash_attention::warp_size;
size_t blocks = (warps + warps_per_block - 1) / warps_per_block;
dim3 grid(blocks, 3);
int threads = flash_attention::block_size;
if (qkvi.scalar_type() == at::ScalarType::Half) {
using dtype = at::Half;
flash_attention::prepare_kernel_fwd<dtype>
<<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
qkvi.data_ptr<dtype>(), qkv.data_ptr<dtype>(), shape[1], shape[2], shape[3], shape[4]);
} else {
using dtype = at::BFloat16;
flash_attention::prepare_kernel_fwd<dtype>
<<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
qkvi.data_ptr<dtype>(), qkv.data_ptr<dtype>(), shape[1], shape[2], shape[3], shape[4]);
}
auto te_qkvi = makeTransformerEngineTensor(qkvi);
auto te_qkv = makeTransformerEngineTensor(qkv);
nvte_prepare_flash_attn_fwd(te_qkvi.data(), te_qkv.data(), at::cuda::getCurrentCUDAStream());
return qkv;
}
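A hedged reading of the layout conversion, inferred from the shape and stride checks and the output shape above; this is a sketch, not a statement of the kernel's contract.

// Assumed layouts (inferred, illustrative only):
//   qkvi : [s, b, n, h] view into an interleaved (Q, K, V) buffer, so the head
//          stride is 3*h and each head's Q/K/V slices sit next to each other
//   qkv  : [3, b, s, n, h], with Q, K and V split out along dim 0
at::Tensor qkv_split = fa_prepare_fwd(qkvi);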
at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
NVTE_CHECK(q.is_contiguous());
NVTE_CHECK(k.is_contiguous());
NVTE_CHECK(v.is_contiguous());
@@ -618,36 +522,18 @@ at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v) {
q.scalar_type() == at::ScalarType::BFloat16);
NVTE_CHECK(k.scalar_type() == q.scalar_type());
NVTE_CHECK(v.scalar_type() == q.scalar_type());
NVTE_CHECK(q.size(3) % flash_attention::load_size == 0);
NVTE_CHECK(q.size(3) == flash_attention::load_size);
NVTE_CHECK(k.size(3) % flash_attention::load_size == 0);
NVTE_CHECK(k.size(3) == flash_attention::load_size);
NVTE_CHECK(v.size(3) % flash_attention::load_size == 0);
NVTE_CHECK(v.size(3) == flash_attention::load_size);
// 3 x [s, b, n, h] -> [b, s, n, 3 * h]
std::vector<int64_t> shape = {q.size(1), q.size(0), q.size(2), 3 * q.size(3)};
at::Tensor qkv = at::empty(shape, at::CUDA(q.scalar_type()));
size_t warps = q.size(0) * q.size(1);
size_t warps_per_block = flash_attention::block_size / flash_attention::warp_size;
size_t blocks = (warps + warps_per_block - 1) / warps_per_block;
dim3 grid(blocks, 3);
int threads = flash_attention::block_size;
if (q.scalar_type() == at::ScalarType::Half) {
using dtype = at::Half;
flash_attention::prepare_kernel_bwd<dtype>
<<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
q.data_ptr<dtype>(), k.data_ptr<dtype>(), v.data_ptr<dtype>(), qkv.data_ptr<dtype>(),
q.size(0), q.size(1), q.size(2), q.size(3));
} else {
using dtype = at::BFloat16;
flash_attention::prepare_kernel_bwd<dtype>
<<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
q.data_ptr<dtype>(), k.data_ptr<dtype>(), v.data_ptr<dtype>(), qkv.data_ptr<dtype>(),
q.size(0), q.size(1), q.size(2), q.size(3));
}
auto te_q = makeTransformerEngineTensor(q);
auto te_k = makeTransformerEngineTensor(k);
auto te_v = makeTransformerEngineTensor(v);
auto te_qkv = makeTransformerEngineTensor(qkv);
nvte_prepare_flash_attn_bwd(te_q.data(), te_k.data(), te_v.data(), te_qkv.data(),
at::cuda::getCurrentCUDAStream());
return qkv;
}
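And the mirror-image sketch for the backward preparation, following the layout comment above; the variable names are illustrative.

// Illustrative: dq, dk, dv each [s, b, n, h]; the returned tensor packs them
// back into the interleaved [b, s, n, 3*h] layout (per the comment above).
at::Tensor dqkv = fa_prepare_bwd(dq, dk, dv);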
@@ -658,6 +544,9 @@ at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v) {
at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_seqlens,
int half_idx) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
NVTE_CHECK(tensor.dim() == 3 || tensor.dim() == 4);
NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);
NVTE_CHECK(cu_seqlens.dim() == 1);
@@ -683,18 +572,12 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_s
shape[seq_dim] /= 2;
at::Tensor half = at::empty(shape, at::CUDA(tensor.scalar_type()));
// Launch Kernel
constexpr unsigned int block = 256;
unsigned int grid_x = (tensor.size(seq_dim) / 2 * 32 + block - 1) / block;
unsigned int grid_y = 1;
for (int i = 0; i < seq_dim; i++) {
grid_y *= tensor.size(i);
}
dim3 grid = {grid_x, grid_y};
transformer_engine::fused_attn::thd_read_half_tensor_kernel<<<
grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
half.data_ptr(), tensor.data_ptr(), cu_seqlens.data_ptr<int>(), batch, hidden_size_in_bytes,
half_idx, tensor.size(seq_dim));
auto te_tensor = makeTransformerEngineTensor(tensor);
auto te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens);
auto te_half = makeTransformerEngineTensor(half);
nvte_cp_thd_read_half_tensor(te_tensor.data(), te_cu_seqlens.data(), te_half.data(), half_idx,
at::cuda::getCurrentCUDAStream());
return half;
}
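A hedged usage sketch under assumed sizes; reading half_idx as selecting the first (0) or second (1) half of every sequence is an inference from the call pattern, not documented in this hunk.

// Illustrative: two sequences of 512 tokens each in THD layout [t, h, d]
at::Tensor q_thd = torch::randn({1024, 16, 64},
                                at::dtype(at::kBFloat16).device(at::kCUDA));
at::Tensor cu_seqlens = torch::tensor({0, 512, 1024},
                                      at::dtype(at::kInt).device(at::kCUDA));
// Result shape is [512, 16, 64]: the selected half of every sequence, still packed as THD
at::Tensor q_half = thd_read_half_tensor(q_thd, cu_seqlens, /*half_idx=*/0);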
@@ -705,6 +588,9 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_s
void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_step,
const at::Tensor &cu_seqlens, bool lse_packed) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float);
NVTE_CHECK(lse_per_step.scalar_type() == at::ScalarType::Float);
NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);
@@ -738,26 +624,20 @@ void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_st
NVTE_CHECK(cu_seqlens.size(0) == batch + 1);
}
constexpr unsigned int block = 256;
unsigned int grid_x = (lse_seqlen / 2 + block - 1) / block;
unsigned int grid_y = num_heads;
dim3 grid = {grid_x, grid_y};
auto te_lse = makeTransformerEngineTensor(lse);
auto te_lse_per_step = makeTransformerEngineTensor(lse_per_step);
auto te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens);
if (lse_packed) {
transformer_engine::fused_attn::thd_lse_kernel<true, LseCorrectionFunctor>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
lse.data_ptr<float>(), lse_per_step.data_ptr<float>(), cu_seqlens.data_ptr<int>(),
batch, num_heads, lse_seqlen, second_half_lse_seqlen);
} else {
transformer_engine::fused_attn::thd_lse_kernel<false, LseCorrectionFunctor>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
lse.data_ptr<float>(), lse_per_step.data_ptr<float>(), cu_seqlens.data_ptr<int>(),
batch, num_heads, lse_seqlen, second_half_lse_seqlen);
}
nvte_cp_thd_second_half_lse_correction(te_lse.data(), te_lse_per_step.data(),
te_cu_seqlens.data(), lse_packed,
at::cuda::getCurrentCUDAStream());
}
at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_seqlens,
bool lse_packed, int second_half_lse_seqlen) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float);
NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);
NVTE_CHECK(cu_seqlens.dim() == 1);
@@ -790,22 +670,13 @@ at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_
at::Tensor half_lse = at::zeros(shape, at::CUDA(lse.scalar_type()));
constexpr unsigned int block = 256;
unsigned int grid_x = (lse_seqlen / 2 + block - 1) / block;
unsigned int grid_y = num_heads;
dim3 grid = {grid_x, grid_y};
auto te_lse = makeTransformerEngineTensor(lse);
auto te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens);
auto te_half_lse = makeTransformerEngineTensor(half_lse);
if (lse_packed) {
transformer_engine::fused_attn::thd_lse_kernel<true, ReadLseFunctor>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
lse.data_ptr<float>(), half_lse.data_ptr<float>(), cu_seqlens.data_ptr<int>(), batch,
num_heads, lse_seqlen, second_half_lse_seqlen);
} else {
transformer_engine::fused_attn::thd_lse_kernel<false, ReadLseFunctor>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
lse.data_ptr<float>(), half_lse.data_ptr<float>(), cu_seqlens.data_ptr<int>(), batch,
num_heads, lse_seqlen, second_half_lse_seqlen);
}
nvte_cp_thd_read_second_half_lse(te_lse.data(), te_cu_seqlens.data(), te_half_lse.data(),
lse_packed, second_half_lse_seqlen,
at::cuda::getCurrentCUDAStream());
return half_lse;
}
@@ -814,194 +685,38 @@ at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_
* Support THD format for Context Parallel: Out correction in forward
**************************************************************************************************/
template <typename dtype, int only_second_half>
static void thd_out_correction_helper(at::Tensor out, const at::Tensor &out_per_step,
const at::Tensor &lse, const at::Tensor &lse_per_step,
const at::Tensor &cu_seqlens, bool lse_packed) {
NVTE_CHECK(out.scalar_type() == out_per_step.scalar_type());
NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float);
NVTE_CHECK(lse_per_step.scalar_type() == at::ScalarType::Float);
NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);
int total_tokens = out.size(0);
int num_heads = out.size(1);
int dim_per_head = out.size(2);
NVTE_CHECK(out_per_step.size(0) == total_tokens / (only_second_half + 1));
NVTE_CHECK(out_per_step.size(1) == num_heads);
NVTE_CHECK(out_per_step.size(2) == dim_per_head);
int batch, lse_seqlen, lse_per_step_seqlen;
if (lse_packed) {
batch = cu_seqlens.size(0) - 1;
lse_seqlen = lse.size(1);
lse_per_step_seqlen = lse_per_step.size(1);
NVTE_CHECK(lse.size(0) == num_heads);
NVTE_CHECK(lse_seqlen >= total_tokens);
NVTE_CHECK(lse_per_step.size(0) == num_heads);
NVTE_CHECK(lse_per_step_seqlen >= lse_seqlen / (only_second_half + 1));
} else {
batch = lse.size(0);
lse_seqlen = lse.size(2);
lse_per_step_seqlen = lse_per_step.size(2);
NVTE_CHECK(lse.size(1) == num_heads);
NVTE_CHECK(lse_per_step.size(0) == batch);
NVTE_CHECK(lse_per_step.size(1) == num_heads);
NVTE_CHECK(lse_per_step_seqlen == lse_seqlen / (only_second_half + 1));
NVTE_CHECK(cu_seqlens.size(0) == batch + 1);
}
constexpr int tile = 16;
constexpr int block = 512;
unsigned int grid_x =
(static_cast<size_t>(total_tokens) / (only_second_half + 1) * tile + block - 1) / block;
dim3 grid = {grid_x, (unsigned int)num_heads};
if (lse_packed) {
transformer_engine::fused_attn::thd_out_correction_kernel<dtype, only_second_half, tile, true>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
out.data_ptr<dtype>(), out_per_step.data_ptr<dtype>(), lse.data_ptr<float>(),
lse_per_step.data_ptr<float>(), cu_seqlens.data_ptr<int>(), batch, num_heads,
dim_per_head, lse_seqlen, lse_per_step_seqlen);
} else {
transformer_engine::fused_attn::thd_out_correction_kernel<dtype, only_second_half, tile, false>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
out.data_ptr<dtype>(), out_per_step.data_ptr<dtype>(), lse.data_ptr<float>(),
lse_per_step.data_ptr<float>(), cu_seqlens.data_ptr<int>(), batch, num_heads,
dim_per_head, lse_seqlen, lse_per_step_seqlen);
}
}
void thd_out_correction(at::Tensor out, const at::Tensor &out_per_step, const at::Tensor &lse,
const at::Tensor &lse_per_step, const at::Tensor &cu_seqlens,
bool only_second_half, bool lse_packed) {
if (only_second_half) {
if (out.scalar_type() == at::ScalarType::Half) {
using dtype = at::Half;
thd_out_correction_helper<dtype, 1>(out, out_per_step, lse, lse_per_step, cu_seqlens,
lse_packed);
} else if (out.scalar_type() == at::ScalarType::BFloat16) {
using dtype = at::BFloat16;
thd_out_correction_helper<dtype, 1>(out, out_per_step, lse, lse_per_step, cu_seqlens,
lse_packed);
} else if (out.scalar_type() == at::ScalarType::Float) {
using dtype = float;
thd_out_correction_helper<dtype, 1>(out, out_per_step, lse, lse_per_step, cu_seqlens,
lse_packed);
} else {
NVTE_ERROR("Unsupported dtype of out\n");
}
} else {
if (out.scalar_type() == at::ScalarType::Half) {
using dtype = at::Half;
thd_out_correction_helper<dtype, 0>(out, out_per_step, lse, lse_per_step, cu_seqlens,
lse_packed);
} else if (out.scalar_type() == at::ScalarType::BFloat16) {
using dtype = at::BFloat16;
thd_out_correction_helper<dtype, 0>(out, out_per_step, lse, lse_per_step, cu_seqlens,
lse_packed);
} else if (out.scalar_type() == at::ScalarType::Float) {
using dtype = float;
thd_out_correction_helper<dtype, 0>(out, out_per_step, lse, lse_per_step, cu_seqlens,
lse_packed);
} else {
NVTE_ERROR("Unsupported dtype of out\n");
}
}
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
auto te_out = makeTransformerEngineTensor(out);
auto te_out_per_step = makeTransformerEngineTensor(out_per_step);
auto te_lse = makeTransformerEngineTensor(lse);
auto te_lse_per_step = makeTransformerEngineTensor(lse_per_step);
auto te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens);
nvte_cp_thd_out_correction(te_out.data(), te_out_per_step.data(), te_lse.data(),
te_lse_per_step.data(), te_cu_seqlens.data(), only_second_half,
lse_packed, at::cuda::getCurrentCUDAStream());
}
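For intuition about what the delegated correction computes, a scalar sketch of the standard log-sum-exp recombination used when merging partial attention outputs; this is background math under assumption, not code introduced by this change.

#include <cmath>

// Per element: rescale the per-step partial output by exp(lse_per_step - lse),
// where lse is the combined log-sum-exp, and accumulate it into the running output.
float merge_partial_out(float out, float out_per_step, float lse, float lse_per_step) {
  return out + out_per_step * std::exp(lse_per_step - lse);
}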
/***************************************************************************************************
* Support THD format for Context Parallel: Gradients correction in backward
**************************************************************************************************/
template <typename dtype, typename Functor_0, typename Functor_1, int functor_idx>
static void thd_grad_correction_helper(at::Tensor grad, const at::Tensor &grad_per_step,
const at::Tensor &cu_seqlens) {
NVTE_CHECK(grad.dim() == 3 || grad.dim() == 4);
NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);
NVTE_CHECK(cu_seqlens.dim() == 1);
// Shape of dq is [t, h, d], so the dimension of "t" is 0
// Shape of dkv is [2, t, h, d], so the dimension of "t" is 1
int seq_dim = grad.dim() == 3 ? 0 : 1;
int total_tokens = grad.size(seq_dim);
int num_heads = grad.size(seq_dim + 1);
int dim_per_head = grad.size(seq_dim + 2);
int batch = cu_seqlens.size(0) - 1;
if constexpr (functor_idx < 2) {
NVTE_CHECK(grad_per_step.size(seq_dim) == total_tokens / 2);
} else {
NVTE_CHECK(grad_per_step.size(seq_dim) == total_tokens);
}
NVTE_CHECK(grad_per_step.size(seq_dim + 1) == num_heads);
NVTE_CHECK(grad_per_step.size(seq_dim + 2) == dim_per_head);
size_t hidden_size = num_heads * dim_per_head;
NVTE_CHECK((hidden_size * c10::elementSize(grad.scalar_type())) % 16 == 0);
constexpr unsigned int block = 256;
unsigned int grid_x;
if constexpr (functor_idx < 2) {
grid_x = (total_tokens / 2 * 32 + block - 1) / block;
} else {
grid_x = (total_tokens * 32 + block - 1) / block;
}
unsigned int grid_y = 1;
for (int i = 0; i < seq_dim; i++) {
grid_y *= grad.size(i);
}
dim3 grid = {grid_x, grid_y};
transformer_engine::fused_attn::thd_grad_correction_kernel<dtype, Functor_0, Functor_1,
functor_idx, 32>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
grad.data_ptr<dtype>(), grad_per_step.data_ptr<dtype>(), cu_seqlens.data_ptr<int>(),
batch, hidden_size, total_tokens);
}
template <typename dtype>
static void thd_grad_dispatcher(at::Tensor grad, const at::Tensor &grad_per_step,
const at::Tensor &cu_seqlens, const std::string &first_half,
const std::string &second_half) {
if (first_half == "add" && second_half == "none") {
thd_grad_correction_helper<dtype, AddFunctor<dtype>, EmptyFunctor, 0>(grad, grad_per_step,
cu_seqlens);
} else if (first_half == "copy" && second_half == "none") {
thd_grad_correction_helper<dtype, CopyFunctor, EmptyFunctor, 0>(grad, grad_per_step,
cu_seqlens);
} else if (first_half == "none" && second_half == "add") {
thd_grad_correction_helper<dtype, EmptyFunctor, AddFunctor<dtype>, 1>(grad, grad_per_step,
cu_seqlens);
} else if (first_half == "none" && second_half == "copy") {
thd_grad_correction_helper<dtype, EmptyFunctor, CopyFunctor, 1>(grad, grad_per_step,
cu_seqlens);
} else if (first_half == "add" && second_half == "copy") {
thd_grad_correction_helper<dtype, AddFunctor<dtype>, CopyFunctor, 2>(grad, grad_per_step,
cu_seqlens);
} else if (first_half == "copy" && second_half == "add") {
thd_grad_correction_helper<dtype, CopyFunctor, AddFunctor<dtype>, 2>(grad, grad_per_step,
cu_seqlens);
} else {
NVTE_ERROR("Unsupported Functor of first half and second_half\n");
}
}
void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step,
const at::Tensor &cu_seqlens, const std::string &first_half,
const std::string &second_half) {
if (grad.scalar_type() == at::ScalarType::Half) {
thd_grad_dispatcher<at::Half>(grad, grad_per_step, cu_seqlens, first_half, second_half);
} else if (grad.scalar_type() == at::ScalarType::BFloat16) {
thd_grad_dispatcher<at::BFloat16>(grad, grad_per_step, cu_seqlens, first_half, second_half);
} else if (grad.scalar_type() == at::ScalarType::Float) {
thd_grad_dispatcher<float>(grad, grad_per_step, cu_seqlens, first_half, second_half);
} else {
NVTE_ERROR("Unsupported dtype of grad\n");
}
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
auto te_grad = makeTransformerEngineTensor(grad);
auto te_grad_per_step = makeTransformerEngineTensor(grad_per_step);
auto te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens);
nvte_cp_thd_grad_correction(te_grad.data(), te_grad_per_step.data(), te_cu_seqlens.data(),
first_half.data(), second_half.data(),
at::cuda::getCurrentCUDAStream());
}
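A hedged usage sketch; the valid ("first_half", "second_half") pairs are the ones the removed dispatcher enumerated ("add"/"none", "copy"/"none", "none"/"add", "none"/"copy", "add"/"copy", "copy"/"add"), and the tensors below are illustrative.

// Illustrative: accumulate this step's dq into the first half of every sequence
// and overwrite the second half with it.
thd_grad_correction(dq, dq_per_step, cu_seqlens, /*first_half=*/"add", /*second_half=*/"copy");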
/***************************************************************************************************
@@ -1010,6 +725,9 @@ void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step,
at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_tokens,
int world_size, int rank) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);
NVTE_CHECK(cu_seqlens.dim() == 1);
NVTE_CHECK(cu_seqlens.size(0) >= 2);
@@ -1022,11 +740,11 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t
std::vector<int64_t> shape = {total_tokens / world_size};
at::Tensor output = at::empty(shape, at::CUDA(at::ScalarType::Int));
constexpr unsigned int block = 256;
unsigned int grid = (output.size(0) + block - 1) / block;
transformer_engine::fused_attn::thd_partition_indices_kernel<<<
grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
output.data_ptr<int>(), cu_seqlens.data_ptr<int>(), batch, total_tokens, world_size, rank);
auto te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens);
auto te_output = makeTransformerEngineTensor(output);
nvte_cp_thd_get_partitioned_indices(te_cu_seqlens.data(), te_output.data(), total_tokens,
world_size, rank, at::cuda::getCurrentCUDAStream());
return output;
}
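A hedged usage sketch; the sizes, the rank variable, and the gather that follows are assumptions for illustration.

// Illustrative: 2048 total tokens sharded across a context-parallel group of 4 ranks;
// the result holds this rank's total_tokens / world_size = 512 token indices (int32).
at::Tensor idx = thd_get_partitioned_indices(cu_seqlens, /*total_tokens=*/2048,
                                             /*world_size=*/4, /*rank=*/cp_rank);
at::Tensor q_local = q_thd.index_select(0, idx.to(at::kLong));  // gather this rank's tokens (assumed usage)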
@@ -1035,39 +753,22 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t
* KV Cache: Convert a tensor from qkv_format = thd to qkv_format = bshd
**************************************************************************************************/
template <typename scalar_t>
void convert_thd_to_bshd_launcher(at::Tensor tensor, at::Tensor new_tensor, at::Tensor cu_seqlens,
int b, int max_seq_len, int h, int d) {
transformer_engine::fused_attn::
convert_thd_to_bshd_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
reinterpret_cast<scalar_t *>(tensor.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(new_tensor.data_ptr<scalar_t>()), cu_seqlens.data_ptr<int>(),
b, max_seq_len, h, d);
}
at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
int h = tensor.size(1);
int d = tensor.size(2);
std::vector<int64_t> shape = {b, max_seq_len, h, d};
at::Tensor new_tensor = at::zeros(shape, at::CUDA(tensor.scalar_type()));
if (new_tensor.scalar_type() == at::ScalarType::Half) {
using dtype = at::Half;
convert_thd_to_bshd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
} else if (new_tensor.scalar_type() == at::ScalarType::BFloat16) {
using dtype = at::BFloat16;
convert_thd_to_bshd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
} else if (new_tensor.scalar_type() == at::ScalarType::Float) {
using dtype = float;
convert_thd_to_bshd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
} else if (new_tensor.scalar_type() == at::ScalarType::Float8_e4m3fn) {
using dtype = at::Float8_e4m3fn;
convert_thd_to_bshd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
} else if (new_tensor.scalar_type() == at::ScalarType::Float8_e5m2) {
using dtype = at::Float8_e5m2;
convert_thd_to_bshd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
} else {
NVTE_ERROR("Unsupported dtype for KV cache.\n");
}
auto te_tensor = makeTransformerEngineTensor(tensor);
auto te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens);
auto te_new_tensor = makeTransformerEngineTensor(new_tensor);
nvte_convert_thd_to_bshd(te_tensor.data(), te_cu_seqlens.data(), te_new_tensor.data(), b,
max_seq_len, at::cuda::getCurrentCUDAStream());
return new_tensor;
}
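A hedged usage sketch with assumed shapes.

// Illustrative: cu_seqlens = [0, 300, 812] packs two sequences (300 and 512 tokens)
// into a [812, h, d] THD tensor; with b = 2 and max_seq_len = 512 the result is a
// [2, 512, h, d] BSHD tensor, zero-filled past each sequence's end.
at::Tensor bshd = convert_thd_to_bshd(thd, cu_seqlens, /*b=*/2, /*max_seq_len=*/512);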
@@ -1075,95 +776,33 @@ at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b,
* KV Cache: Convert a tensor from qkv_format = bshd to qkv_format = thd
**************************************************************************************************/
template <typename scalar_t>
void convert_bshd_to_thd_launcher(at::Tensor tensor, at::Tensor new_tensor, at::Tensor cu_seqlens,
int b, int max_seq_len, int h, int d) {
transformer_engine::fused_attn::
convert_bshd_to_thd_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
reinterpret_cast<scalar_t *>(tensor.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(new_tensor.data_ptr<scalar_t>()), cu_seqlens.data_ptr<int>(),
b, max_seq_len, h, d);
}
at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t) {
int b = tensor.size(0);
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
int max_seq_len = tensor.size(1);
int h = tensor.size(2);
int d = tensor.size(3);
std::vector<int64_t> shape = {t, h, d};
at::Tensor new_tensor = at::zeros(shape, at::CUDA(tensor.scalar_type()));
if (tensor.scalar_type() == at::ScalarType::Half) {
using dtype = at::Half;
convert_bshd_to_thd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
} else if (tensor.scalar_type() == at::ScalarType::BFloat16) {
using dtype = at::BFloat16;
convert_bshd_to_thd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
} else if (tensor.scalar_type() == at::ScalarType::Float) {
using dtype = float;
convert_bshd_to_thd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
} else if (tensor.scalar_type() == at::ScalarType::Float8_e4m3fn) {
using dtype = at::Float8_e4m3fn;
convert_bshd_to_thd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
} else if (tensor.scalar_type() == at::ScalarType::Float8_e5m2) {
using dtype = at::Float8_e5m2;
convert_bshd_to_thd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
} else {
NVTE_ERROR("Unsupported dtype for KV cache.\n");
}
return new_tensor;
}
/***************************************************************************************************
* KV Cache: Copy new KV tokens to the KV cache
* 1. new_k and new_v are in qkv_format; k_cache and v_cache are in 'bshd' format
* 2. cu_new_lens and cu_cached_lens are in shape [b + 1]; cu_cached_lens include the added lens
* in current step
* 3. Non-paged KV cache is a special case of paged KV cache, with page_table = [b, 1] and
* max_pages_per_seq = 1. We use the same underlying kernel for both non-paged and paged.
* Set is_non_paged = True/False to indicate which case applies.
* 4. is_non_paged = True also re-indexes the KV cache, e.g. the initial batch indices [0, 3, 1, 2]
* become [0, 1, 1, 2]. The page_table = batch_indices.unsqueeze(1) is however unchanged.
* batch_indices_post can be used for monotonic indexing, i.e. [0, 1, 2, 3]. batch_indices is
* preserved for the next layer in the same iteration.
* 5. Only supports same page_table for k_cache and v_cache
* 6. Only pad_between_seqs = False is supported when qkv_format = thd, i.e. there should be no pad tokens
* between sequences in new_k and new_v such as [a a a 0..0 b b 0..0 c 0..0].
**************************************************************************************************/
auto te_tensor = makeTransformerEngineTensor(tensor);
auto te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens);
auto te_new_tensor = makeTransformerEngineTensor(new_tensor);
template <typename scalar_t>
void copy_to_kv_cache_launcher(at::Tensor new_k, at::Tensor new_v, at::Tensor k_cache,
at::Tensor v_cache, at::Tensor page_table, at::Tensor cu_new_lens,
at::Tensor cu_cached_lens, NVTE_QKV_Format qkv_format, int h_kv,
int d_k, int d_v, int b, int max_ctx_len, int max_seq_len,
int max_pages_per_seq, bool is_non_paged) {
if (new_k.data_ptr() != nullptr && new_v.data_ptr() != nullptr && k_cache.data_ptr() != nullptr &&
v_cache.data_ptr() != nullptr) {
if (is_non_paged) {
transformer_engine::fused_attn::
reindex_kv_cache_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
reinterpret_cast<scalar_t *>(k_cache.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(v_cache.data_ptr<scalar_t>()),
page_table.data_ptr<int>(), cu_new_lens.data_ptr<int>(),
cu_cached_lens.data_ptr<int>(), h_kv, d_k, d_v, b, max_seq_len);
}
transformer_engine::fused_attn::
copy_to_kv_cache_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
reinterpret_cast<scalar_t *>(new_k.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(new_v.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(k_cache.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(v_cache.data_ptr<scalar_t>()), page_table.data_ptr<int>(),
cu_new_lens.data_ptr<int>(), cu_cached_lens.data_ptr<int>(), qkv_format, h_kv, d_k, d_v,
b, max_ctx_len, max_seq_len, max_pages_per_seq, is_non_paged);
}
nvte_convert_bshd_to_thd(te_tensor.data(), te_cu_seqlens.data(), te_new_tensor.data(), t,
at::cuda::getCurrentCUDAStream());
return new_tensor;
}
void copy_to_kv_cache(at::Tensor new_k, at::Tensor new_v, at::Tensor k_cache, at::Tensor v_cache,
at::Tensor page_table, at::Tensor cu_new_lens, at::Tensor cu_cached_lens,
NVTE_QKV_Format qkv_format, int b, int max_ctx_len, int max_seq_len,
int max_pages_per_seq, bool is_non_paged) {
int h_kv = new_k.size(-2);
int d_k = new_k.size(-1);
int d_v = new_v.size(-1);
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
NVTE_CHECK(k_cache.scalar_type() == v_cache.scalar_type() &&
new_k.scalar_type() == new_v.scalar_type() &&
new_k.scalar_type() == k_cache.scalar_type(),
@@ -1171,33 +810,17 @@ void copy_to_kv_cache(at::Tensor new_k, at::Tensor new_v, at::Tensor k_cache, at
NVTE_CHECK(qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD ||
qkv_format == NVTE_QKV_Format::NVTE_THD,
"qkv_format must be {BSHD, SBHD, THD}.");
if (k_cache.scalar_type() == at::ScalarType::Half) {
using dtype = at::Half;
copy_to_kv_cache_launcher<dtype>(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens,
cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len,
max_seq_len, max_pages_per_seq, is_non_paged);
} else if (k_cache.scalar_type() == at::ScalarType::BFloat16) {
using dtype = at::BFloat16;
copy_to_kv_cache_launcher<dtype>(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens,
cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len,
max_seq_len, max_pages_per_seq, is_non_paged);
} else if (k_cache.scalar_type() == at::ScalarType::Float) {
using dtype = float;
copy_to_kv_cache_launcher<dtype>(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens,
cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len,
max_seq_len, max_pages_per_seq, is_non_paged);
} else if (k_cache.scalar_type() == at::ScalarType::Float8_e4m3fn) {
using dtype = at::Float8_e4m3fn;
copy_to_kv_cache_launcher<dtype>(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens,
cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len,
max_seq_len, max_pages_per_seq, is_non_paged);
} else if (k_cache.scalar_type() == at::ScalarType::Float8_e5m2) {
using dtype = at::Float8_e5m2;
copy_to_kv_cache_launcher<dtype>(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens,
cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len,
max_seq_len, max_pages_per_seq, is_non_paged);
} else {
NVTE_ERROR("Unsupported dtype for KV cache.\n");
}
auto te_new_k = makeTransformerEngineTensor(new_k);
auto te_new_v = makeTransformerEngineTensor(new_v);
auto te_k_cache = makeTransformerEngineTensor(k_cache);
auto te_v_cache = makeTransformerEngineTensor(v_cache);
auto te_page_table = makeTransformerEngineTensor(page_table);
auto te_cu_new_lens = makeTransformerEngineTensor(cu_new_lens);
auto te_cu_cached_lens = makeTransformerEngineTensor(cu_cached_lens);
nvte_copy_to_kv_cache(te_new_k.data(), te_new_v.data(), te_k_cache.data(), te_v_cache.data(),
te_page_table.data(), te_cu_new_lens.data(), te_cu_cached_lens.data(),
qkv_format, b, max_ctx_len, max_seq_len, max_pages_per_seq, is_non_paged,
at::cuda::getCurrentCUDAStream());
}
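Finally, a hedged sketch of the non-paged call pattern described in the comment block above; all tensor names and sizes are assumptions.

// Illustrative non-paged setup: per the comment above, page_table degenerates to
// batch_indices.unsqueeze(1) with max_pages_per_seq = 1, so every sequence owns one cache row.
at::Tensor page_table = batch_indices.unsqueeze(1).to(at::kInt);
copy_to_kv_cache(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens, cu_cached_lens,
                 NVTE_QKV_Format::NVTE_THD, /*b=*/batch_size, /*max_ctx_len=*/max_ctx_len,
                 /*max_seq_len=*/max_seq_len, /*max_pages_per_seq=*/1, /*is_non_paged=*/true);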