"...git@developer.sourcefind.cn:kecinstone/2024-pra-vllm.git" did not exist on "24cde76a152fbffde30fa2be0d08dcbad490530e"
Unverified Commit 51cd4415 authored by Kirthi Shankar Sivamani, committed by GitHub

[C][PyTorch] Make pytorch extensions pure cpp (#1754)



* First pass refactor
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* first pass
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* core compiles
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Include cuda dirs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Compiles
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix test
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Move grad outside autocast
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix kv cache
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Address review comments
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Change src file name in cmake
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* move the kernels too
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Move comment
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Move comments around
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* more movement
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* move
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent dfe1a65a
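In short: after this change the PyTorch extension is compiled with the host C++ compiler only; all .cu kernels move into the core library and are reached through new nvte_* C APIs. A minimal sketch of the resulting build call, with placeholder paths rather than the actual arguments computed in setup.py below:

from torch.utils.cpp_extension import CppExtension

ext = CppExtension(
    name="transformer_engine_torch",
    sources=["csrc/extensions/pybind.cpp"],  # placeholder: C++ sources only, no .cu files
    include_dirs=["transformer_engine/common/include"],  # placeholder include path
    extra_compile_args={"cxx": ["-O3", "-fvisibility=hidden"]},  # no "nvcc" flags anymore
)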
@@ -8,11 +8,7 @@ from pathlib import Path
import setuptools
-from .utils import (
-    all_files_in_dir,
-    cuda_archs,
-    cuda_version,
-)
+from .utils import all_files_in_dir, cuda_version, get_cuda_include_dirs
def setup_pytorch_extension(
@@ -30,55 +26,30 @@ def setup_pytorch_extension(
    ] + all_files_in_dir(extensions_dir)
    # Header files
-    include_dirs = [
-        common_header_files,
-        common_header_files / "common",
-        common_header_files / "common" / "include",
-        csrc_header_files,
-    ]
+    include_dirs = get_cuda_include_dirs()
+    include_dirs.extend(
+        [
+            common_header_files,
+            common_header_files / "common",
+            common_header_files / "common" / "include",
+            csrc_header_files,
+        ]
+    )
    # Compiler flags
    cxx_flags = [
        "-O3",
        "-fvisibility=hidden",
    ]
-    nvcc_flags = [
-        "-O3",
-        "-U__CUDA_NO_HALF_OPERATORS__",
-        "-U__CUDA_NO_HALF_CONVERSIONS__",
-        "-U__CUDA_NO_BFLOAT16_OPERATORS__",
-        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
-        "-U__CUDA_NO_BFLOAT162_OPERATORS__",
-        "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
-        "--expt-relaxed-constexpr",
-        "--expt-extended-lambda",
-        "--use_fast_math",
-    ]
-    cuda_architectures = cuda_archs()
-    if "70" in cuda_architectures:
-        nvcc_flags.extend(["-gencode", "arch=compute_70,code=sm_70"])
    # Version-dependent CUDA options
    try:
        version = cuda_version()
    except FileNotFoundError:
-        print("Could not determine CUDA Toolkit version")
+        print("Could not determine CUDA version")
    else:
        if version < (12, 0):
            raise RuntimeError("Transformer Engine requires CUDA 12.0 or newer")
-        nvcc_flags.extend(
-            (
-                "--threads",
-                os.getenv("NVTE_BUILD_THREADS_PER_JOB", "1"),
-            )
-        )
-        for arch in cuda_architectures.split(";"):
-            if arch == "70":
-                continue  # Already handled
-            nvcc_flags.extend(["-gencode", f"arch=compute_{arch},code=sm_{arch}"])
    if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
        assert (
@@ -87,7 +58,6 @@ def setup_pytorch_extension(
        mpi_path = Path(os.getenv("MPI_HOME"))
        include_dirs.append(mpi_path / "include")
        cxx_flags.append("-DNVTE_UB_WITH_MPI")
-        nvcc_flags.append("-DNVTE_UB_WITH_MPI")
    library_dirs = []
    libraries = []
@@ -100,21 +70,17 @@ def setup_pytorch_extension(
        library_dirs.append(nvshmem_home / "lib")
        libraries.append("nvshmem_host")
        cxx_flags.append("-DNVTE_ENABLE_NVSHMEM")
-        nvcc_flags.append("-DNVTE_ENABLE_NVSHMEM")
    # Construct PyTorch CUDA extension
    sources = [str(path) for path in sources]
    include_dirs = [str(path) for path in include_dirs]
-    from torch.utils.cpp_extension import CUDAExtension
+    from torch.utils.cpp_extension import CppExtension
-    return CUDAExtension(
+    return CppExtension(
        name="transformer_engine_torch",
        sources=[str(src) for src in sources],
        include_dirs=[str(inc) for inc in include_dirs],
-        extra_compile_args={
-            "cxx": cxx_flags,
-            "nvcc": nvcc_flags,
-        },
+        extra_compile_args={"cxx": cxx_flags},
        libraries=[str(lib) for lib in libraries],
        library_dirs=[str(lib_dir) for lib_dir in library_dirs],
    )
@@ -1628,8 +1628,8 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoP
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
        )
    if is_training:
        out.backward(out_grad)
    param_names = []
    param_names.append("hidden_states.grad")
@@ -1879,8 +1879,8 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training):
            checkpoint_core_attention=False,
            core_attention_bias_type=config.attn_bias_type,
        )
    if is_training:
        out.backward(out_grad)
    if is_training:
        return out, (inp[0].grad, inp[1].grad, inp[2].grad)
@@ -1993,7 +1993,7 @@ def _run_custom_mha_fp8(dtype, config, backend):
    mha = Custom_MHA_FP8(config).to(dtype=dtype, device="cuda")
    with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        out = mha(inp, cu_seqlens, config.max_seqlen_q)
    out.backward(out_grad)
    out = torch.load("out.pt")
    dqkv = torch.load("dqkv.pt")
...
@@ -130,18 +130,20 @@ def dtype_tols(dtype: torch.dtype) -> Dict[str, float]:
def assert_allclose(
-    l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float, rtol: float = None
+    l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float = None, rtol: float = None
) -> bool:
    """Ensures two lists are equal."""
    assert len(l1) == len(l2), "Unequal number of outputs."
    for i, (t1, t2) in enumerate(zip(l1, l2)):
-        tols = dict(atol=atol)
+        tols = dtype_tols(t2.dtype)
        if rtol is not None:
            tols["rtol"] = rtol
+        if atol is not None:
+            tols["atol"] = atol
        result = torch.allclose(t1, t2, **tols)
        if not result:
            diff = torch.abs(t1 - t2)
-            tol = atol + (rtol * torch.abs(t2))
+            tol = tols["atol"] + (tols["rtol"] * torch.abs(t2))
            exceed_mask = diff > tol
            if exceed_mask.any():
                indices = torch.nonzero(exceed_mask, as_tuple=True)
...
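With the updated assert_allclose both tolerances are optional and default to the dtype-based values from dtype_tols, with explicit arguments taking precedence. A small usage sketch (placeholder tensors, assuming the test helpers above are imported):

import torch

out_ref = torch.randn(8, dtype=torch.bfloat16)
out_test = out_ref.clone()
assert_allclose([out_test], [out_ref])                       # dtype-based default atol/rtol
assert_allclose([out_test], [out_ref], atol=1e-2, rtol=0.0)  # explicit values override the defaults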
@@ -66,6 +66,9 @@ list(APPEND transformer_engine_SOURCES
    transpose/quantize_transpose_square_blockwise.cu
    transpose/quantize_transpose_vector_blockwise.cu
    activation/gelu.cu
+   fused_attn/flash_attn.cu
+   fused_attn/context_parallel.cu
+   fused_attn/kv_cache.cu
    fused_attn/fused_attn_f16_max512_seqlen.cu
    fused_attn/fused_attn_f16_arbitrary_seqlen.cu
    activation/relu.cu
@@ -173,6 +176,9 @@ set_source_files_properties(fused_softmax/scaled_masked_softmax.cu
    multi_tensor/l2norm.cu
    multi_tensor/scale.cu
    multi_tensor/sgd.cu
+   fused_attn/flash_attn.cu
+   fused_attn/context_parallel.cu
+   fused_attn/kv_cache.cu
    PROPERTIES
    COMPILE_OPTIONS "--use_fast_math")
option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF)
...
@@ -111,7 +111,7 @@ struct Tensor {
        columnwise_scale_inv(nullptr, {1}, DType::kFloat32),
        scaling_mode(NVTE_DELAYED_TENSOR_SCALING) {}
-  int numel() const {
+  size_t numel() const {
    size_t acc = 1;
    for (const auto dim : shape()) {
      acc *= dim;
@@ -133,6 +133,14 @@ struct Tensor {
    return data.dtype;
  }
size_t dim() const {
if (!has_data() && has_columnwise_data()) {
return columnwise_data.shape.size();
} else {
return data.shape.size();
}
}
  std::vector<size_t> shape() const {
    /* Note: We sometimes experience spurious compiler errors
     * (-Wstringop-overflow) from this function. It appears that GCC
@@ -385,6 +393,33 @@ struct TypeInfo {
      NVTE_ERROR("Invalid type."); \
  }
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(dtype, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
case DType::kFloat32: { \
using type = float; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat16: { \
using type = fp16; \
{ __VA_ARGS__ } \
} break; \
case DType::kBFloat16: { \
using type = bf16; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat8E4M3: { \
using type = fp8e4m3; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat8E5M2: { \
using type = fp8e5m2; \
{ __VA_ARGS__ } \
} break; \
default: \
NVTE_ERROR("Invalid type."); \
}
#define TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(dtype, type, ...) \
  switch (dtype) { \
    using namespace transformer_engine; \
...
@@ -3,13 +3,17 @@
 *
 * See LICENSE for license information.
 ************************************************************************/
-#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_THD_UTILS_CUH_
-#define TRANSFORMER_ENGINE_FUSED_ATTN_THD_UTILS_CUH_
#include <assert.h>
#include <cuda.h>
#include <cuda_bf16.h>
+#include "../common.h"
+#include "transformer_engine/fused_attn.h"
+namespace transformer_engine {
+namespace context_parallel {
struct LseCorrectionFunctor {
  __forceinline__ __device__ static void run(float *lse, float *half_lse, size_t idx,
                                             size_t half_idx) {
@@ -49,16 +53,13 @@ struct AddFunctor {
#pragma unroll
    for (int i = 0; i < sizeof(float4) / sizeof(dtype); i++) {
-      p_[i] += p[i];
+      p_[i] = p_[i] + p[i];
    }
    reinterpret_cast<float4 *>(token)[idx] = d_;
  }
};
-namespace transformer_engine {
-namespace fused_attn {
/***************************************************************************************************
 * Support THD format for Context Parallel: Binary search an array for a target value
 **************************************************************************************************/
@@ -107,6 +108,7 @@ __global__ void thd_partition_indices_kernel(int *output, int *cu_seqlens, int b
/***************************************************************************************************
 * Support THD format for Context Parallel: Read the half of a THD tensor
 **************************************************************************************************/
__global__ void thd_read_half_tensor_kernel(void *half, void *tensor, int *cu_seqlens, int batch,
                                            int hidden_size_in_bytes, int half_idx,
                                            int dim_size_of_token) {
@@ -232,7 +234,10 @@ __global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float
      dtype *p_per_step = reinterpret_cast<dtype *>(&data_per_step);
      dtype *p = reinterpret_cast<dtype *>(&data);
      for (int k = 0; k < sizeof(float4) / sizeof(dtype); k++) {
-        p[k] += (p_per_step[k] == 0 ? 0 : p_per_step[k] * lse_corrected_exp);
+        p[k] = p[k] +
+               (p_per_step[k] == static_cast<dtype>(0.f)
+                    ? static_cast<dtype>(0.f)
+                    : static_cast<dtype>(static_cast<float>(p_per_step[k]) * lse_corrected_exp));
      }
      reinterpret_cast<float4 *>(cur_out)[j] = data;
    }
@@ -297,6 +302,442 @@ __global__ void thd_grad_correction_kernel(dtype *grad, dtype *grad_per_step, in
  }
}
-} // namespace fused_attn
/***************************************************************************************************
 * Support THD format for Context Parallel: Read the half of a THD tensor
 **************************************************************************************************/
void thd_read_half_tensor(const Tensor &tensor, const Tensor &cu_seqlens, Tensor &half,
int half_idx, cudaStream_t stream) {
using namespace transformer_engine;
NVTE_CHECK(tensor.dim() == 3 || tensor.dim() == 4);
NVTE_CHECK(cu_seqlens.dtype() == DType::kInt32);
auto cu_seqlens_shape = cu_seqlens.shape();
auto tensor_shape = tensor.shape();
NVTE_CHECK(cu_seqlens.dim() == 1);
NVTE_CHECK(cu_seqlens_shape[0] >= 2);
// Shapes of q and dq are [t, h, d], so the dimension of "t" is 0
// Shapes of kv and dkv are [2, t, h, d], so the dimension of "t" is 1
int seq_dim = tensor.dim() == 3 ? 0 : 1;
int batch = cu_seqlens_shape[0] - 1;
int num_heads = tensor_shape[seq_dim + 1];
int dim_per_head = tensor_shape[seq_dim + 2];
int hidden_size_in_bytes = num_heads * dim_per_head * typeToSize(tensor.dtype());
// For 128-bits load/store
NVTE_CHECK(hidden_size_in_bytes % 16 == 0);
// Launch Kernel
constexpr unsigned int block = 256;
unsigned int grid_x = (tensor_shape[seq_dim] / 2 * 32 + block - 1) / block;
unsigned int grid_y = 1;
for (int i = 0; i < seq_dim; i++) {
grid_y *= tensor_shape[i];
}
dim3 grid = {grid_x, grid_y};
thd_read_half_tensor_kernel<<<grid, block, sizeof(int) * (batch + 1), stream>>>(
half.data.dptr, tensor.data.dptr, reinterpret_cast<int *>(cu_seqlens.data.dptr), batch,
hidden_size_in_bytes, half_idx, tensor_shape[seq_dim]);
}
/***************************************************************************************************
* Support THD format for Context Parallel: softmax_lse related operations
**************************************************************************************************/
void thd_second_half_lse_correction(Tensor lse, const Tensor &lse_per_step,
const Tensor &cu_seqlens, bool lse_packed,
cudaStream_t stream) {
using namespace transformer_engine;
NVTE_CHECK(lse.dtype() == DType::kFloat32);
NVTE_CHECK(lse_per_step.dtype() == DType::kFloat32);
NVTE_CHECK(cu_seqlens.dtype() == DType::kInt32);
NVTE_CHECK(cu_seqlens.dim() == 1);
int batch, num_heads, lse_seqlen, second_half_lse_seqlen;
auto cu_seqlens_shape = cu_seqlens.shape();
auto lse_shape = lse.shape();
auto lse_per_step_shape = lse_per_step.shape();
if (lse_packed) {
NVTE_CHECK(lse.dim() == 2);
NVTE_CHECK(lse_per_step.dim() == 2);
batch = cu_seqlens_shape[0] - 1;
num_heads = lse_shape[0];
lse_seqlen = lse_shape[1];
second_half_lse_seqlen = lse_per_step_shape[1];
NVTE_CHECK(lse_per_step_shape[0] == num_heads);
NVTE_CHECK(second_half_lse_seqlen >= lse_seqlen / 2);
} else {
NVTE_CHECK(lse.dim() == 3);
NVTE_CHECK(lse_per_step.dim() == 3);
batch = lse_shape[0];
num_heads = lse_shape[1];
lse_seqlen = lse_shape[2];
second_half_lse_seqlen = lse_per_step_shape[2];
NVTE_CHECK(lse_per_step_shape[0] == batch);
NVTE_CHECK(lse_per_step_shape[1] == num_heads);
NVTE_CHECK(second_half_lse_seqlen == lse_seqlen / 2);
NVTE_CHECK(cu_seqlens_shape[0] == batch + 1);
}
constexpr unsigned int block = 256;
unsigned int grid_x = (lse_seqlen / 2 + block - 1) / block;
unsigned int grid_y = num_heads;
dim3 grid = {grid_x, grid_y};
if (lse_packed) {
thd_lse_kernel<true, LseCorrectionFunctor><<<grid, block, sizeof(int) * (batch + 1), stream>>>(
reinterpret_cast<float *>(lse.data.dptr), reinterpret_cast<float *>(lse_per_step.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen,
second_half_lse_seqlen);
} else {
thd_lse_kernel<false, LseCorrectionFunctor><<<grid, block, sizeof(int) * (batch + 1), stream>>>(
reinterpret_cast<float *>(lse.data.dptr), reinterpret_cast<float *>(lse_per_step.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen,
second_half_lse_seqlen);
}
}
void thd_read_second_half_lse(const Tensor &lse, const Tensor &cu_seqlens, Tensor &half_lse,
bool lse_packed, int second_half_lse_seqlen, cudaStream_t stream) {
using namespace transformer_engine;
NVTE_CHECK(lse.dtype() == DType::kFloat32);
NVTE_CHECK(cu_seqlens.dtype() == DType::kInt32);
NVTE_CHECK(cu_seqlens.dim() == 1);
int batch, num_heads, lse_seqlen;
auto cu_seqlens_shape = cu_seqlens.shape();
auto lse_shape = lse.shape();
if (lse_packed) {
NVTE_CHECK(lse.dim() == 2);
batch = cu_seqlens_shape[0] - 1;
num_heads = lse_shape[0];
lse_seqlen = lse_shape[1];
NVTE_CHECK(second_half_lse_seqlen >= lse_seqlen / 2);
} else {
NVTE_CHECK(lse.dim() == 3);
batch = lse_shape[0];
num_heads = lse_shape[1];
lse_seqlen = lse_shape[2];
NVTE_CHECK(cu_seqlens_shape[0] == batch + 1);
NVTE_CHECK(second_half_lse_seqlen == lse_seqlen / 2);
}
constexpr unsigned int block = 256;
unsigned int grid_x = (lse_seqlen / 2 + block - 1) / block;
unsigned int grid_y = num_heads;
dim3 grid = {grid_x, grid_y};
if (lse_packed) {
thd_lse_kernel<true, ReadLseFunctor><<<grid, block, sizeof(int) * (batch + 1), stream>>>(
reinterpret_cast<float *>(lse.data.dptr), reinterpret_cast<float *>(half_lse.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen,
second_half_lse_seqlen);
} else {
thd_lse_kernel<false, ReadLseFunctor><<<grid, block, sizeof(int) * (batch + 1), stream>>>(
reinterpret_cast<float *>(lse.data.dptr), reinterpret_cast<float *>(half_lse.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen,
second_half_lse_seqlen);
}
}
/***************************************************************************************************
* Support THD format for Context Parallel: Out correction in forward
**************************************************************************************************/
template <typename dtype, int only_second_half>
static void thd_out_correction_helper(Tensor out, const Tensor &out_per_step, const Tensor &lse,
const Tensor &lse_per_step, const Tensor &cu_seqlens,
bool lse_packed, cudaStream_t stream) {
using namespace transformer_engine;
NVTE_CHECK(out.dtype() == out_per_step.dtype());
NVTE_CHECK(lse.dtype() == DType::kFloat32);
NVTE_CHECK(lse_per_step.dtype() == DType::kFloat32);
NVTE_CHECK(cu_seqlens.dtype() == DType::kInt32);
auto out_shape = out.shape();
auto lse_shape = lse.shape();
auto out_per_step_shape = out_per_step.shape();
auto lse_per_step_shape = lse_per_step.shape();
auto cu_seqlens_shape = cu_seqlens.shape();
int total_tokens = out_shape[0];
int num_heads = out_shape[1];
int dim_per_head = out_shape[2];
NVTE_CHECK(out_per_step_shape[0] == total_tokens / (only_second_half + 1));
NVTE_CHECK(out_per_step_shape[1] == num_heads);
NVTE_CHECK(out_per_step_shape[2] == dim_per_head);
int batch, lse_seqlen, lse_per_step_seqlen;
if (lse_packed) {
batch = cu_seqlens_shape[0] - 1;
lse_seqlen = lse_shape[1];
lse_per_step_seqlen = lse_per_step_shape[1];
NVTE_CHECK(lse_shape[0] == num_heads);
NVTE_CHECK(lse_seqlen >= total_tokens);
NVTE_CHECK(lse_per_step_shape[0] == num_heads);
NVTE_CHECK(lse_per_step_seqlen >= lse_seqlen / (only_second_half + 1));
} else {
batch = lse_shape[0];
lse_seqlen = lse_shape[2];
lse_per_step_seqlen = lse_per_step_shape[2];
NVTE_CHECK(lse_shape[1] == num_heads);
NVTE_CHECK(lse_per_step_shape[0] == batch);
NVTE_CHECK(lse_per_step_shape[1] == num_heads);
NVTE_CHECK(lse_per_step_seqlen == lse_seqlen / (only_second_half + 1));
NVTE_CHECK(cu_seqlens_shape[0] == batch + 1);
}
constexpr int tile = 16;
constexpr int block = 512;
unsigned int grid_x =
(static_cast<size_t>(total_tokens) / (only_second_half + 1) * tile + block - 1) / block;
dim3 grid = {grid_x, (unsigned int)num_heads};
if (lse_packed) {
thd_out_correction_kernel<dtype, only_second_half, tile, true>
<<<grid, block, sizeof(int) * (batch + 1), stream>>>(
reinterpret_cast<dtype *>(out.data.dptr),
reinterpret_cast<dtype *>(out_per_step.data.dptr),
reinterpret_cast<float *>(lse.data.dptr),
reinterpret_cast<float *>(lse_per_step.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, dim_per_head,
lse_seqlen, lse_per_step_seqlen);
} else {
thd_out_correction_kernel<dtype, only_second_half, tile, false>
<<<grid, block, sizeof(int) * (batch + 1), stream>>>(
reinterpret_cast<dtype *>(out.data.dptr),
reinterpret_cast<dtype *>(out_per_step.data.dptr),
reinterpret_cast<float *>(lse.data.dptr),
reinterpret_cast<float *>(lse_per_step.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, num_heads, dim_per_head,
lse_seqlen, lse_per_step_seqlen);
}
}
void thd_out_correction(Tensor out, const Tensor &out_per_step, const Tensor &lse,
const Tensor &lse_per_step, const Tensor &cu_seqlens, bool only_second_half,
bool lse_packed, cudaStream_t stream) {
using namespace transformer_engine;
if (only_second_half) {
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
out.dtype(), dtype,
thd_out_correction_helper<dtype, 1>(out, out_per_step, lse, lse_per_step, cu_seqlens,
lse_packed, stream););
} else {
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
out.dtype(), dtype,
thd_out_correction_helper<dtype, 0>(out, out_per_step, lse, lse_per_step, cu_seqlens,
lse_packed, stream););
}
}
/***************************************************************************************************
* Support THD format for Context Parallel: Gradients correction in backward
**************************************************************************************************/
template <typename dtype, typename Functor_0, typename Functor_1, int functor_idx>
static void thd_grad_correction_helper(Tensor grad, const Tensor &grad_per_step,
const Tensor &cu_seqlens, cudaStream_t stream) {
using namespace transformer_engine;
NVTE_CHECK(grad.dim() == 3 || grad.dim() == 4);
NVTE_CHECK(cu_seqlens.dtype() == DType::kInt32);
NVTE_CHECK(cu_seqlens.dim() == 1);
auto grad_shape = grad.shape();
auto cu_seqlens_shape = cu_seqlens.shape();
auto grad_per_step_shape = grad_per_step.shape();
// Shape of dq is [t, h, d], so the dimension of "t" is 0
// Shape of dkv is [2, t, h, d], so the dimension of "t" is 1
int seq_dim = grad.dim() == 3 ? 0 : 1;
int total_tokens = grad_shape[seq_dim];
int num_heads = grad_shape[seq_dim + 1];
int dim_per_head = grad_shape[seq_dim + 2];
int batch = cu_seqlens_shape[0] - 1;
if constexpr (functor_idx < 2) {
NVTE_CHECK(grad_per_step_shape[seq_dim] == total_tokens / 2);
} else {
NVTE_CHECK(grad_per_step_shape[seq_dim] == total_tokens);
}
NVTE_CHECK(grad_per_step_shape[seq_dim + 1] == num_heads);
NVTE_CHECK(grad_per_step_shape[seq_dim + 2] == dim_per_head);
size_t hidden_size = num_heads * dim_per_head;
NVTE_CHECK((hidden_size * typeToSize(grad.dtype())) % 16 == 0);
constexpr unsigned int block = 256;
unsigned int grid_x;
if constexpr (functor_idx < 2) {
grid_x = (total_tokens / 2 * 32 + block - 1) / block;
} else {
grid_x = (total_tokens * 32 + block - 1) / block;
}
unsigned int grid_y = 1;
for (int i = 0; i < seq_dim; i++) {
grid_y *= grad_shape[i];
}
dim3 grid = {grid_x, grid_y};
thd_grad_correction_kernel<dtype, Functor_0, Functor_1, functor_idx, 32>
<<<grid, block, sizeof(int) * (batch + 1), stream>>>(
reinterpret_cast<dtype *>(grad.data.dptr),
reinterpret_cast<dtype *>(grad_per_step.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), batch, hidden_size, total_tokens);
}
template <typename dtype>
static void thd_grad_dispatcher(Tensor grad, const Tensor &grad_per_step, const Tensor &cu_seqlens,
const std::string &first_half, const std::string &second_half,
cudaStream_t stream) {
using namespace transformer_engine;
if (first_half == "add" && second_half == "none") {
thd_grad_correction_helper<dtype, AddFunctor<dtype>, EmptyFunctor, 0>(grad, grad_per_step,
cu_seqlens, stream);
} else if (first_half == "copy" && second_half == "none") {
thd_grad_correction_helper<dtype, CopyFunctor, EmptyFunctor, 0>(grad, grad_per_step, cu_seqlens,
stream);
} else if (first_half == "none" && second_half == "add") {
thd_grad_correction_helper<dtype, EmptyFunctor, AddFunctor<dtype>, 1>(grad, grad_per_step,
cu_seqlens, stream);
} else if (first_half == "none" && second_half == "copy") {
thd_grad_correction_helper<dtype, EmptyFunctor, CopyFunctor, 1>(grad, grad_per_step, cu_seqlens,
stream);
} else if (first_half == "add" && second_half == "copy") {
thd_grad_correction_helper<dtype, AddFunctor<dtype>, CopyFunctor, 2>(grad, grad_per_step,
cu_seqlens, stream);
} else if (first_half == "copy" && second_half == "add") {
thd_grad_correction_helper<dtype, CopyFunctor, AddFunctor<dtype>, 2>(grad, grad_per_step,
cu_seqlens, stream);
} else {
NVTE_ERROR("Unsupported Functor of first half and second_half\n");
}
}
void thd_grad_correction(Tensor grad, const Tensor &grad_per_step, const Tensor &cu_seqlens,
const std::string &first_half, const std::string &second_half,
cudaStream_t stream) {
using namespace transformer_engine;
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
grad.dtype(), dtype,
thd_grad_dispatcher<dtype>(grad, grad_per_step, cu_seqlens, first_half, second_half,
stream););
}
/***************************************************************************************************
* Support THD format for Context Parallel: Generate partitioned indices for input tokens
**************************************************************************************************/
void thd_get_partitioned_indices(const Tensor &cu_seqlens, Tensor output, int total_tokens,
int world_size, int rank, cudaStream_t stream) {
using namespace transformer_engine;
NVTE_CHECK(cu_seqlens.dtype() == DType::kInt32);
NVTE_CHECK(cu_seqlens.dim() == 1);
auto cu_seqlens_shape = cu_seqlens.shape();
auto output_shape = output.shape();
NVTE_CHECK(cu_seqlens_shape[0] >= 2);
NVTE_CHECK(rank >= 0 && rank < world_size);
NVTE_CHECK(world_size > 0);
NVTE_CHECK(total_tokens > 0 && total_tokens % (world_size * 2) == 0);
int batch = cu_seqlens_shape[0] - 1;
constexpr unsigned int block = 256;
unsigned int grid = (output_shape[0] + block - 1) / block;
thd_partition_indices_kernel<<<grid, block, sizeof(int) * (batch + 1), stream>>>(
reinterpret_cast<int *>(output.data.dptr), reinterpret_cast<int *>(cu_seqlens.data.dptr),
batch, total_tokens, world_size, rank);
}
} // namespace context_parallel
} // namespace transformer_engine
-#endif
void nvte_cp_thd_read_half_tensor(const NVTETensor &tensor, const NVTETensor &cu_seqlens,
NVTETensor half, int half_idx, cudaStream_t stream) {
NVTE_API_CALL(nvte_thd_read_half_tensor);
using namespace transformer_engine;
context_parallel::thd_read_half_tensor(*reinterpret_cast<Tensor *>(tensor),
*reinterpret_cast<Tensor *>(cu_seqlens),
*reinterpret_cast<Tensor *>(half), half_idx, stream);
}
void nvte_cp_thd_second_half_lse_correction(NVTETensor lse, const NVTETensor &lse_per_step,
const NVTETensor &cu_seqlens, int lse_packed,
cudaStream_t stream) {
NVTE_API_CALL(nvte_thd_second_half_lse_correction);
using namespace transformer_engine;
context_parallel::thd_second_half_lse_correction(
*reinterpret_cast<Tensor *>(lse), *reinterpret_cast<Tensor *>(lse_per_step),
*reinterpret_cast<Tensor *>(cu_seqlens), lse_packed, stream);
}
void nvte_cp_thd_read_second_half_lse(const NVTETensor &lse, const NVTETensor &cu_seqlens,
NVTETensor half_lse, int lse_packed,
int second_half_lse_seqlen, cudaStream_t stream) {
NVTE_API_CALL(nvte_thd_read_second_half_lse);
using namespace transformer_engine;
context_parallel::thd_read_second_half_lse(
*reinterpret_cast<Tensor *>(lse), *reinterpret_cast<Tensor *>(cu_seqlens),
*reinterpret_cast<Tensor *>(half_lse), lse_packed, second_half_lse_seqlen, stream);
}
void nvte_cp_thd_out_correction(NVTETensor out, const NVTETensor &out_per_step,
const NVTETensor &lse, const NVTETensor &lse_per_step,
const NVTETensor &cu_seqlens, int only_second_half, int lse_packed,
cudaStream_t stream) {
NVTE_API_CALL(nvte_thd_out_correction);
using namespace transformer_engine;
context_parallel::thd_out_correction(
*reinterpret_cast<Tensor *>(out), *reinterpret_cast<Tensor *>(out_per_step),
*reinterpret_cast<Tensor *>(lse), *reinterpret_cast<Tensor *>(lse_per_step),
*reinterpret_cast<Tensor *>(cu_seqlens), only_second_half, lse_packed, stream);
}
void nvte_cp_thd_grad_correction(NVTETensor grad, const NVTETensor &grad_per_step,
const NVTETensor &cu_seqlens, const char *first_half,
const char *second_half, cudaStream_t stream) {
NVTE_API_CALL(nvte_thd_grad_correction);
using namespace transformer_engine;
std::string first_half_str(first_half);
std::string second_half_str(second_half);
context_parallel::thd_grad_correction(
*reinterpret_cast<Tensor *>(grad), *reinterpret_cast<Tensor *>(grad_per_step),
*reinterpret_cast<Tensor *>(cu_seqlens), first_half_str, second_half_str, stream);
}
void nvte_cp_thd_get_partitioned_indices(const NVTETensor &cu_seqlens, NVTETensor output,
int total_tokens, int world_size, int rank,
cudaStream_t stream) {
NVTE_API_CALL(nvte_thd_get_partitioned_indices);
using namespace transformer_engine;
context_parallel::thd_get_partitioned_indices(*reinterpret_cast<Tensor *>(cu_seqlens),
*reinterpret_cast<Tensor *>(output), total_tokens,
world_size, rank, stream);
}
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "../common.h"
#include "transformer_engine/fused_attn.h"
namespace transformer_engine {
namespace flash_attention {
constexpr int warp_size = 32;
constexpr int type_size = 2; // FP16 or BF16
constexpr int nvec = sizeof(uint64_t) / type_size;
constexpr int load_size = warp_size * nvec;
constexpr int block_size = 512;
template <typename T>
__launch_bounds__(block_size) __global__
void prepare_kernel_fwd(const T *qkvi, T *qkv, const size_t B, const size_t S, const size_t Z,
const size_t W) {
const int warpid = (blockDim.x * blockIdx.x + threadIdx.x) / warp_size;
const int id_in_warp = threadIdx.x % warp_size;
const size_t offset_input = blockIdx.y * W + warpid * 3 * W * Z + id_in_warp * nvec;
const T *my_input = qkvi + offset_input;
const size_t s = warpid / B;
if (s >= S) return;
const size_t b = warpid % B;
const size_t offset_output = blockIdx.y * B * S * Z * W + (s + b * S) * W * Z + id_in_warp * nvec;
T *my_output = qkv + offset_output;
for (int i = 0; i < Z; ++i) {
uint64_t *out = reinterpret_cast<uint64_t *>(my_output + i * load_size);
*out = *reinterpret_cast<const uint64_t *>(my_input + i * load_size * 3);
}
}
template <typename T>
__launch_bounds__(block_size) __global__
void prepare_kernel_bwd(const T *q, const T *k, const T *v, T *qkv, const size_t B,
const size_t S, const size_t Z, const size_t W) {
const T *input = blockIdx.y == 0 ? q : (blockIdx.y == 1 ? k : v);
const int warpid = (blockDim.x * blockIdx.x + threadIdx.x) / warp_size;
const int id_in_warp = threadIdx.x % warp_size;
const size_t offset_input = warpid * W * Z + id_in_warp * nvec;
const T *my_input = input + offset_input;
const size_t b = warpid / S;
if (b >= B) return;
const size_t s = warpid % S;
const size_t offset_output = (b + s * B) * 3 * W * Z + id_in_warp * nvec + blockIdx.y * W;
T *my_output = qkv + offset_output;
for (int i = 0; i < Z; ++i) {
uint64_t *out = reinterpret_cast<uint64_t *>(my_output + i * load_size * 3);
*out = *reinterpret_cast<const uint64_t *>(my_input + i * load_size);
}
}
void prepare_flash_attn_fwd(Tensor qkvi, Tensor qkv, cudaStream_t stream) {
using namespace transformer_engine;
NVTE_CHECK(qkvi.dim() == 4, "Expected 4-dim tensor.");
NVTE_CHECK(qkvi.dtype() == DType::kFloat16 || qkvi.dtype() == DType::kBFloat16);
auto qkvi_shape = qkvi.shape();
NVTE_CHECK(qkvi_shape[3] % load_size == 0);
NVTE_CHECK(qkvi_shape[3] == load_size);
// [s, b, n, h * 3] -> [3, b, s, n, h]
std::vector<uint64_t> shape = {3, qkvi_shape[1], qkvi_shape[0], qkvi_shape[2], qkvi_shape[3]};
size_t warps = qkvi_shape[0] * qkvi_shape[1];
size_t warps_per_block = block_size / warp_size;
size_t blocks = (warps + warps_per_block - 1) / warps_per_block;
dim3 grid(blocks, 3);
int threads = block_size;
TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(
qkvi.dtype(), dtype,
prepare_kernel_fwd<dtype><<<grid, threads, 0, stream>>>(
reinterpret_cast<dtype *>(qkvi.data.dptr), reinterpret_cast<dtype *>(qkv.data.dptr),
shape[1], shape[2], shape[3], shape[4]););
}
void prepare_flash_attn_bwd(Tensor q, Tensor k, Tensor v, Tensor qkv, cudaStream_t stream) {
using namespace transformer_engine;
NVTE_CHECK(q.dim() == 4, "Expected 4-dim tensor.");
NVTE_CHECK(k.dim() == 4, "Expected 4-dim tensor.");
NVTE_CHECK(v.dim() == 4, "Expected 4-dim tensor.");
NVTE_CHECK(q.dtype() == DType::kFloat16 || q.dtype() == DType::kBFloat16);
NVTE_CHECK(k.dtype() == q.dtype());
NVTE_CHECK(v.dtype() == q.dtype());
auto q_shape = q.shape();
auto k_shape = k.shape();
auto v_shape = v.shape();
NVTE_CHECK(q_shape[3] % load_size == 0);
NVTE_CHECK(q_shape[3] == load_size);
NVTE_CHECK(k_shape[3] % load_size == 0);
NVTE_CHECK(k_shape[3] == load_size);
NVTE_CHECK(v_shape[3] % load_size == 0);
NVTE_CHECK(v_shape[3] == load_size);
// 3 x [s, b, n, h] -> [b, s, n, 3 * h]
std::vector<uint64_t> shape = {q_shape[1], q_shape[0], q_shape[2], 3 * q_shape[3]};
size_t warps = q_shape[0] * q_shape[1];
size_t warps_per_block = block_size / warp_size;
size_t blocks = (warps + warps_per_block - 1) / warps_per_block;
dim3 grid(blocks, 3);
int threads = block_size;
TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(
q.dtype(), dtype,
prepare_kernel_bwd<dtype><<<grid, threads, 0, stream>>>(
reinterpret_cast<dtype *>(q.data.dptr), reinterpret_cast<dtype *>(k.data.dptr),
reinterpret_cast<dtype *>(v.data.dptr), reinterpret_cast<dtype *>(qkv.data.dptr),
q_shape[0], q_shape[1], q_shape[2], q_shape[3]););
}
} // namespace flash_attention
} // namespace transformer_engine
void nvte_prepare_flash_attn_fwd(NVTETensor qkvi, NVTETensor qkv, cudaStream_t stream) {
NVTE_API_CALL(nvte_prepare_flash_attn_fwd);
using namespace transformer_engine;
flash_attention::prepare_flash_attn_fwd(*reinterpret_cast<Tensor *>(qkvi),
*reinterpret_cast<Tensor *>(qkv), stream);
}
void nvte_prepare_flash_attn_bwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTETensor qkv,
cudaStream_t stream) {
NVTE_API_CALL(nvte_prepare_flash_attn_bwd);
using namespace transformer_engine;
flash_attention::prepare_flash_attn_bwd(
*reinterpret_cast<Tensor *>(q), *reinterpret_cast<Tensor *>(k),
*reinterpret_cast<Tensor *>(v), *reinterpret_cast<Tensor *>(qkv), stream);
}
@@ -3,48 +3,15 @@
 *
 * See LICENSE for license information.
 ************************************************************************/
-#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_KV_CACHE_CUH_
-#define TRANSFORMER_ENGINE_FUSED_ATTN_KV_CACHE_CUH_
-namespace transformer_engine {
-namespace fused_attn {
-template <typename scalar_t>
-__global__ void convert_thd_to_bshd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens,
-                                           int b, int max_seq_len, int h, int d) {
-  // tensor: thd; new_tensor: bshd
-  // cu_seqlens: [b + 1]
-  for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
-    int num_elts = (cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]) * h * d;
-    int thd_offset = cu_seqlens[batch_idx] * h * d;
-    int bshd_offset = batch_idx * max_seq_len * h * d;
-    scalar_t *thd_token = tensor + thd_offset;
-    scalar_t *bshd_token = new_tensor + bshd_offset;
-    for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
-      *(bshd_token + i) = *(thd_token + i);
-    }
-  }
-}
-template <typename scalar_t>
-__global__ void convert_bshd_to_thd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens,
-                                           int b, int max_seq_len, int h, int d) {
-  // tensor: bshd; new_tensor: thd
-  // cu_seqlens: [b + 1]
-  for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
-    int seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx];
-    int num_elts = seqlen * h * d;
-    int bshd_offset = batch_idx * max_seq_len * h * d;
-    int thd_offset = cu_seqlens[batch_idx] * h * d;
-    scalar_t *bshd_token = tensor + bshd_offset;
-    scalar_t *thd_token = new_tensor + thd_offset;
-    for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
-      *(thd_token + i) = *(bshd_token + i);
-    }
-  }
-}
-template <typename scalar_t>
-__global__ void reindex_kv_cache_kernel(scalar_t *k_cache, scalar_t *v_cache, int *batch_indices,
+#include "../common.h"
+#include "transformer_engine/fused_attn.h"
+namespace transformer_engine {
+namespace kv_cache {
+template <typename dtype>
+__global__ void reindex_kv_cache_kernel(dtype *k_cache, dtype *v_cache, int *batch_indices,
                                        int *cu_new_lens, int *cu_cached_lens, int h_kv, int d_k,
                                        int d_v, int b, int max_seq_len) {
  // k_cache, v_cache: bshd
@@ -75,11 +42,11 @@ __global__ void reindex_kv_cache_kernel(scalar_t *k_cache, scalar_t *v_cache, in
  }
}
-template <typename scalar_t>
-__global__ void copy_to_kv_cache_kernel(scalar_t *new_k, scalar_t *new_v, scalar_t *k_cache,
-                                        scalar_t *v_cache, int *page_table, int *cu_new_lens,
-                                        int *cu_cached_lens, NVTE_QKV_Format qkv_format, int h_kv,
-                                        int d_k, int d_v, int b, int max_ctx_len, int max_seq_len,
+template <typename dtype>
+__global__ void copy_to_kv_cache_kernel(dtype *new_k, dtype *new_v, dtype *k_cache, dtype *v_cache,
+                                        int *page_table, int *cu_new_lens, int *cu_cached_lens,
+                                        NVTE_QKV_Format qkv_format, int h_kv, int d_k, int d_v,
+                                        int b, int max_ctx_len, int max_seq_len,
                                        int max_pages_per_seq, bool is_non_paged) {
  // new_k, new_v: qkv_format; k_cache, v_cache: bshd
  // cu_new_lens, cu_cached_lens: [b + 1]
@@ -140,6 +107,191 @@ __global__ void copy_to_kv_cache_kernel(scalar_t *new_k, scalar_t *new_v, scalar
  }
}
}
-} // namespace fused_attn
template <typename dtype>
void copy_to_kv_cache_launcher(Tensor new_k, Tensor new_v, Tensor k_cache, Tensor v_cache,
Tensor page_table, Tensor cu_new_lens, Tensor cu_cached_lens,
NVTE_QKV_Format qkv_format, int h_kv, int d_k, int d_v, int b,
int max_ctx_len, int max_seq_len, int max_pages_per_seq,
bool is_non_paged, cudaStream_t stream) {
if (new_k.has_data() && new_v.has_data() && k_cache.has_data() && v_cache.has_data()) {
if (is_non_paged) {
reindex_kv_cache_kernel<<<16, 256, 0, stream>>>(
reinterpret_cast<dtype *>(k_cache.data.dptr),
reinterpret_cast<dtype *>(v_cache.data.dptr),
reinterpret_cast<int *>(page_table.data.dptr),
reinterpret_cast<int *>(cu_new_lens.data.dptr),
reinterpret_cast<int *>(cu_cached_lens.data.dptr), h_kv, d_k, d_v, b, max_seq_len);
}
copy_to_kv_cache_kernel<<<16, 256, 0, stream>>>(
reinterpret_cast<dtype *>(new_k.data.dptr), reinterpret_cast<dtype *>(new_v.data.dptr),
reinterpret_cast<dtype *>(k_cache.data.dptr), reinterpret_cast<dtype *>(v_cache.data.dptr),
reinterpret_cast<int *>(page_table.data.dptr),
reinterpret_cast<int *>(cu_new_lens.data.dptr),
reinterpret_cast<int *>(cu_cached_lens.data.dptr), qkv_format, h_kv, d_k, d_v, b,
max_ctx_len, max_seq_len, max_pages_per_seq, is_non_paged);
}
}
void copy_to_kv_cache(Tensor new_k, Tensor new_v, Tensor k_cache, Tensor v_cache, Tensor page_table,
Tensor cu_new_lens, Tensor cu_cached_lens, NVTE_QKV_Format qkv_format, int b,
int max_ctx_len, int max_seq_len, int max_pages_per_seq, bool is_non_paged,
cudaStream_t stream) {
int h_kv = new_k.shape()[new_k.dim() - 2];
int d_k = new_k.shape()[new_k.dim() - 1];
int d_v = new_v.shape()[new_v.dim() - 1];
NVTE_CHECK(k_cache.dtype() == v_cache.dtype() && new_k.dtype() == new_v.dtype() &&
new_k.dtype() == k_cache.dtype(),
"new_k, new_v, k_cache and v_cache must be of the same data type.");
NVTE_CHECK(qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD ||
qkv_format == NVTE_QKV_Format::NVTE_THD,
"qkv_format must be {BSHD, SBHD, THD}.");
TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(
k_cache.dtype(), dtype,
copy_to_kv_cache_launcher<dtype>(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens,
cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len,
max_seq_len, max_pages_per_seq, is_non_paged, stream););
}
template <typename scalar_t>
__global__ void convert_thd_to_bshd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens,
int b, int max_seq_len, int h, int d) {
// tensor: thd; new_tensor: bshd
// cu_seqlens: [b + 1]
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
int num_elts = (cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]) * h * d;
int thd_offset = cu_seqlens[batch_idx] * h * d;
int bshd_offset = batch_idx * max_seq_len * h * d;
scalar_t *thd_token = tensor + thd_offset;
scalar_t *bshd_token = new_tensor + bshd_offset;
for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
*(bshd_token + i) = *(thd_token + i);
}
}
}
template <typename scalar_t>
void convert_thd_to_bshd_launcher(Tensor tensor, Tensor new_tensor, Tensor cu_seqlens, int b,
int max_seq_len, int h, int d, cudaStream_t stream) {
using namespace transformer_engine;
convert_thd_to_bshd_kernel<<<16, 256, 0, stream>>>(
reinterpret_cast<scalar_t *>(tensor.data.dptr),
reinterpret_cast<scalar_t *>(new_tensor.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), b, max_seq_len, h, d);
}
void convert_thd_to_bshd(Tensor tensor, Tensor cu_seqlens, Tensor new_tensor, int b,
int max_seq_len, cudaStream_t stream) {
using namespace transformer_engine;
auto tensor_shape = tensor.shape();
TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(
new_tensor.dtype(), dtype,
convert_thd_to_bshd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len,
tensor_shape[1], tensor_shape[2], stream););
}
template <typename scalar_t>
__global__ void convert_bshd_to_thd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens,
int b, int max_seq_len, int h, int d) {
// tensor: bshd; new_tensor: thd
// cu_seqlens: [b + 1]
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
int seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx];
int num_elts = seqlen * h * d;
int bshd_offset = batch_idx * max_seq_len * h * d;
int thd_offset = cu_seqlens[batch_idx] * h * d;
scalar_t *bshd_token = tensor + bshd_offset;
scalar_t *thd_token = new_tensor + thd_offset;
for (int i = threadIdx.x; i < num_elts; i += blockDim.x) {
*(thd_token + i) = *(bshd_token + i);
}
}
}
template <typename scalar_t>
void convert_bshd_to_thd_launcher(Tensor tensor, Tensor new_tensor, Tensor cu_seqlens, int b,
int max_seq_len, int h, int d, cudaStream_t stream) {
using namespace transformer_engine;
convert_bshd_to_thd_kernel<<<16, 256, 0, stream>>>(
reinterpret_cast<scalar_t *>(tensor.data.dptr),
reinterpret_cast<scalar_t *>(new_tensor.data.dptr),
reinterpret_cast<int *>(cu_seqlens.data.dptr), b, max_seq_len, h, d);
}
void convert_bshd_to_thd(Tensor tensor, Tensor cu_seqlens, Tensor new_tensor, int t,
cudaStream_t stream) {
using namespace transformer_engine;
auto tensor_shape = tensor.shape();
TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(
tensor.dtype(), dtype,
convert_bshd_to_thd_launcher<dtype>(tensor, new_tensor, cu_seqlens, tensor_shape[0],
tensor_shape[1], tensor_shape[2], tensor_shape[3],
stream););
}
} // namespace kv_cache
} // namespace transformer_engine
-#endif
/***************************************************************************************************
* KV Cache: Copy new KV tokens to the KV cache
* 1. new_k and new_v are in qkv_format; k_cache and v_cache are in 'bshd' format
* 2. cu_new_lens and cu_cached_lens are in shape [b + 1]; cu_cached_lens include the added lens
* in current step
* 3. Non-paged KV cache is a special case of paged KV cache, with page_table = [b, 1] and
* max_pages_per_seq = 1. We use the same underlying kernel for both non-paged and paged.
* Set is_non_paged = True/False to indicate as such.
* 4. is_non_paged = True also re-indexes the KV cache, e.g. the initial batch indices [0, 3, 1, 2]
* becomes [0, 1, 1, 2]. The page_table = batch_indices.unsqueeze(1) is however unchanged.
* batch_indices_post can be used for monotonical indexing, i.e. [0, 1, 2, 3]. batch_indices is
* preserved for the next layer in the same iteration.
* 5. Only supports same page_table for k_cache and v_cache
* 6. Only pad_between_seqs = False when qkv_format = thd, i.e. there should be no pad tokens
* between sequences in new_k and new_v such as [a a a 0..0 b b 0..0 c 0..0].
**************************************************************************************************/
void nvte_copy_to_kv_cache(NVTETensor new_k, NVTETensor new_v, NVTETensor k_cache,
NVTETensor v_cache, NVTETensor page_table, NVTETensor cu_new_lens,
NVTETensor cu_cached_lens, NVTE_QKV_Format qkv_format, int b,
int max_ctx_len, int max_seq_len, int max_pages_per_seq,
int is_non_paged, cudaStream_t stream) {
NVTE_API_CALL(nvte_copy_to_kv_cache);
using namespace transformer_engine;
kv_cache::copy_to_kv_cache(
*reinterpret_cast<Tensor *>(new_k), *reinterpret_cast<Tensor *>(new_v),
*reinterpret_cast<Tensor *>(k_cache), *reinterpret_cast<Tensor *>(v_cache),
*reinterpret_cast<Tensor *>(page_table), *reinterpret_cast<Tensor *>(cu_new_lens),
*reinterpret_cast<Tensor *>(cu_cached_lens), qkv_format, b, max_ctx_len, max_seq_len,
max_pages_per_seq, is_non_paged, stream);
}
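A hedged, torch-only illustration of points 3 and 4 in the comment above (batch_indices is the caller-side bookkeeping mentioned in the comment, not an argument of this API): a non-paged cache is expressed as a paged cache with a single page per sequence.

import torch

batch_indices = torch.tensor([0, 3, 1, 2], dtype=torch.int32)
page_table = batch_indices.unsqueeze(1)  # shape [b, 1], so max_pages_per_seq == 1
# With is_non_paged=True the kernel additionally re-indexes the KV cache in place,
# while page_table itself is left unchanged for the next layer in the same iteration.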
/***************************************************************************************************
* KV Cache: Convert a tensor from qkv_format = thd to qkv_format = bshd
**************************************************************************************************/
void nvte_convert_thd_to_bshd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor,
int b, int max_seq_len, cudaStream_t stream) {
NVTE_API_CALL(nvte_convert_thd_to_bshd);
using namespace transformer_engine;
kv_cache::convert_thd_to_bshd(*reinterpret_cast<Tensor *>(tensor),
*reinterpret_cast<Tensor *>(cu_seqlens),
*reinterpret_cast<Tensor *>(new_tensor), b, max_seq_len, stream);
}
/***************************************************************************************************
* KV Cache: Convert a tensor from qkv_format = bshd to qkv_format = thd
**************************************************************************************************/
void nvte_convert_bshd_to_thd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor,
int t, cudaStream_t stream) {
NVTE_API_CALL(nvte_convert_bshd_to_thd);
using namespace transformer_engine;
kv_cache::convert_bshd_to_thd(*reinterpret_cast<Tensor *>(tensor),
*reinterpret_cast<Tensor *>(cu_seqlens),
*reinterpret_cast<Tensor *>(new_tensor), t, stream);
}
@@ -610,5 +610,27 @@ uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cud
  return hout;
}
__global__ void extract_seed_and_offset(int64_t *rng_state_ptr, bool captured, int64_t *seed_ptr,
uint64_t seed_val, int64_t *offset_ptr, uint64_t offset_val,
uint32_t offset_intragraph) {
if (captured) {
rng_state_ptr[0] = *seed_ptr;
rng_state_ptr[1] = static_cast<int64_t>(*offset_ptr + static_cast<int64_t>(offset_intragraph));
} else {
rng_state_ptr[0] = static_cast<int64_t>(seed_val);
rng_state_ptr[1] = static_cast<int64_t>(offset_val);
}
}
} // namespace fused_attn
} // namespace transformer_engine
void nvte_extract_seed_and_offset(int64_t *rng_state_ptr, int captured, int64_t *seed_ptr,
uint64_t seed_val, int64_t *offset_ptr, uint64_t offset_val,
uint32_t offset_intragraph, cudaStream_t stream) {
NVTE_API_CALL(nvte_extract_seed_and_offset);
using namespace transformer_engine;
fused_attn::extract_seed_and_offset<<<1, 1, 0, stream>>>(
rng_state_ptr, captured, seed_ptr, seed_val, offset_ptr, offset_val, offset_intragraph);
}
@@ -244,7 +244,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
 * \param[in] stream CUDA stream used for this operation.
 */
void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S,
-                                   NVTETensor O, NVTETensorPack* Aux_CTX_Tensors,
+                                   NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
                                   const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
                                   const NVTETensor rng_state, size_t max_seqlen, bool is_training,
                                   float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
@@ -300,7 +300,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
 */
void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO,
                                   const NVTETensor S, NVTETensor dP,
-                                   const NVTETensorPack* Aux_CTX_Tensors, NVTETensor dQKV,
+                                   const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV,
                                   NVTETensor dBias, const NVTETensor cu_seqlens,
                                   const NVTETensor cu_seqlens_padded, size_t max_seqlen,
                                   float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
@@ -368,7 +368,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
 */
void nvte_fused_attn_fwd_kvpacked(
    const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O,
-    NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
+    NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
    const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
    const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
    size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
@@ -429,7 +429,7 @@ void nvte_fused_attn_fwd_kvpacked(
 */
void nvte_fused_attn_bwd_kvpacked(
    const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO,
-    const NVTETensor S, NVTETensor dP, const NVTETensorPack* Aux_CTX_Tensors, NVTETensor dQ,
+    const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ,
    NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
    const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
    size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout,
@@ -500,7 +500,7 @@ void nvte_fused_attn_bwd_kvpacked(
 */
void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
                         const NVTETensor Bias, NVTETensor S, NVTETensor O,
-                         NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
+                         NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
                         const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
                         const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
                         const NVTETensor page_table_v, const NVTETensor rng_state,
@@ -569,7 +569,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
 */
void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
                         const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP,
-                         const NVTETensorPack* Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK,
+                         const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK,
                         NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q,
                         const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
                         const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q,
@@ -604,6 +604,51 @@ void nvte_populate_rng_state_async(NVTETensor rng_state_dst, const NVTETensor se
uint32_t nvte_get_runtime_num_segments(NVTETensor cu_seqlen, NVTETensor workspace, size_t len,
                                       cudaStream_t stream);
void nvte_extract_seed_and_offset(int64_t *rng_state_ptr, int captured, int64_t *seed_ptr,
uint64_t seed_val, int64_t *offset_ptr, uint64_t offset_val,
uint32_t offset_intragraph, cudaStream_t stream);
void nvte_copy_to_kv_cache(NVTETensor new_k, NVTETensor new_v, NVTETensor k_cache,
NVTETensor v_cache, NVTETensor page_table, NVTETensor cu_new_lens,
NVTETensor cu_cached_lens, NVTE_QKV_Format qkv_format, int b,
int max_ctx_len, int max_seq_len, int max_pages_per_seq,
int is_non_paged, cudaStream_t stream);
void nvte_cp_thd_read_half_tensor(const NVTETensor &tensor, const NVTETensor &cu_seqlens,
NVTETensor half, int half_idx, cudaStream_t stream);
void nvte_cp_thd_second_half_lse_correction(NVTETensor lse, const NVTETensor &lse_per_step,
const NVTETensor &cu_seqlens, int lse_packed,
cudaStream_t stream);
void nvte_cp_thd_read_second_half_lse(const NVTETensor &lse, const NVTETensor &cu_seqlens,
NVTETensor half_lse, int lse_packed,
int second_half_lse_seqlen, cudaStream_t stream);
void nvte_cp_thd_out_correction(NVTETensor out, const NVTETensor &out_per_step,
const NVTETensor &lse, const NVTETensor &lse_per_step,
const NVTETensor &cu_seqlens, int only_second_half, int lse_packed,
cudaStream_t stream);
void nvte_cp_thd_grad_correction(NVTETensor grad, const NVTETensor &grad_per_step,
const NVTETensor &cu_seqlens, const char *first_half,
const char *second_half, cudaStream_t stream);
void nvte_cp_thd_get_partitioned_indices(const NVTETensor &cu_seqlens, NVTETensor output,
int total_tokens, int world_size, int rank,
cudaStream_t stream);
void nvte_convert_thd_to_bshd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor,
int b, int max_seq_len, cudaStream_t stream);
void nvte_convert_bshd_to_thd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor,
int t, cudaStream_t stream);
void nvte_prepare_flash_attn_fwd(NVTETensor qkvi, NVTETensor qkv, cudaStream_t stream);
void nvte_prepare_flash_attn_bwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTETensor qkv,
cudaStream_t stream);
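
The declarations above form the C surface that the PyTorch extension now calls into, so the extension itself can be built as pure C++ without nvcc. A minimal sketch of the calling pattern, assuming the extension's makeTransformerEngineTensor helper (names here are illustrative, not part of this diff):

// Illustrative only: wrap ATen tensors as NVTETensors and dispatch to the C API
// on the current CUDA stream (mirrors the wrappers that appear later in this diff).
at::Tensor thd_to_bshd_example(at::Tensor thd, at::Tensor cu_seqlens, int b, int max_seq_len) {
  at::Tensor bshd = at::zeros({b, max_seq_len, thd.size(1), thd.size(2)},
                              at::CUDA(thd.scalar_type()));
  auto te_in  = makeTransformerEngineTensor(thd);        // non-owning NVTETensor view
  auto te_cu  = makeTransformerEngineTensor(cu_seqlens);
  auto te_out = makeTransformerEngineTensor(bshd);
  nvte_convert_thd_to_bshd(te_in.data(), te_cu.data(), te_out.data(), b, max_seq_len,
                           at::cuda::getCurrentCUDAStream());
  return bshd;
}
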
#ifdef __cplusplus #ifdef __cplusplus
} // extern "C" } // extern "C"
#endif #endif
......
...@@ -218,6 +218,26 @@ std::vector<size_t> getTensorShape(at::Tensor t); ...@@ -218,6 +218,26 @@ std::vector<size_t> getTensorShape(at::Tensor t);
transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
const std::string& fp8_recipe); const std::string& fp8_recipe);
inline size_t typeToSize(transformer_engine::DType t) {
switch (t) {
case transformer_engine::DType::kInt64:
return 8;
case transformer_engine::DType::kInt32:
case transformer_engine::DType::kFloat32:
return 4;
case transformer_engine::DType::kInt16:
case transformer_engine::DType::kFloat16:
case transformer_engine::DType::kBFloat16:
return 2;
case transformer_engine::DType::kByte:
case transformer_engine::DType::kFloat8E4M3:
case transformer_engine::DType::kFloat8E5M2:
return 1;
default:
NVTE_ERROR("Invalid type");
}
}
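
A quick usage sketch (illustrative only, not part of this diff): byte counts for contiguous regions are just element counts multiplied by typeToSize, which is how the refactored zero-fill below sizes its memset.

// Illustrative helper: size in bytes of `n` elements of a TE dtype.
inline size_t bytesFor(transformer_engine::DType t, size_t n) { return n * typeToSize(t); }
// e.g. bytesFor(transformer_engine::DType::kBFloat16, 1024) == 2048
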
inline at::ScalarType GetATenDType(transformer_engine::DType t) { inline at::ScalarType GetATenDType(transformer_engine::DType t) {
switch (t) { switch (t) {
case transformer_engine::DType::kInt16: case transformer_engine::DType::kInt16:
......
...@@ -72,10 +72,10 @@ at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v); ...@@ -72,10 +72,10 @@ at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v);
at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len); at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len);
at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t); at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t);
void copy_to_kv_cache(torch::Tensor new_k, torch::Tensor new_v, torch::Tensor k_cache, void copy_to_kv_cache(at::Tensor new_k, at::Tensor new_v, at::Tensor k_cache, at::Tensor v_cache,
torch::Tensor v_cache, torch::Tensor page_table, torch::Tensor cu_new_lens, at::Tensor page_table, at::Tensor cu_new_lens, at::Tensor cu_cached_lens,
torch::Tensor cu_cached_lens, NVTE_QKV_Format kv_format, int b, NVTE_QKV_Format kv_format, int b, int max_ctx_len, int max_seq_len,
int max_ctx_len, int max_seq_len, int max_pages_per_seq, bool is_non_paged); int max_pages_per_seq, bool is_non_paged);
/*************************************************************************************************** /***************************************************************************************************
* GEMM * GEMM
...@@ -392,12 +392,11 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output, ...@@ -392,12 +392,11 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
namespace nvshmem_api { namespace nvshmem_api {
void init_nvshmem_backend(c10d::ProcessGroup *process_group); void init_nvshmem_backend(c10d::ProcessGroup *process_group);
torch::Tensor create_nvshmem_tensor(const std::vector<int64_t> &shape, c10::ScalarType dtype); at::Tensor create_nvshmem_tensor(const std::vector<int64_t> &shape, c10::ScalarType dtype);
void nvshmem_send_on_current_stream(torch::Tensor src, torch::Tensor dst, int peer, void nvshmem_send_on_current_stream(at::Tensor src, at::Tensor dst, int peer, at::Tensor signal);
torch::Tensor signal);
void nvshmem_wait_on_current_stream(torch::Tensor signal, const std::string &wait_kind); void nvshmem_wait_on_current_stream(at::Tensor signal, const std::string &wait_kind);
void nvshmem_finalize(); void nvshmem_finalize();
} // namespace nvshmem_api } // namespace nvshmem_api
......
...@@ -5,12 +5,8 @@ ...@@ -5,12 +5,8 @@
************************************************************************/ ************************************************************************/
#include "extensions.h" #include "extensions.h"
#include "kv_cache.cuh"
#include "thd_utils.cuh"
#include "transformer_engine/transformer_engine.h"
constexpr int block_size = 512; constexpr int block_size = 512;
constexpr int ctas_per_sm = 4;
// get the fused attention backend // get the fused attention backend
NVTE_Fused_Attn_Backend get_fused_attn_backend( NVTE_Fused_Attn_Backend get_fused_attn_backend(
...@@ -26,19 +22,6 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( ...@@ -26,19 +22,6 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(
return fused_attention_backend; return fused_attention_backend;
} }
// fast zero-fills of tensors
template <typename scalar_t>
__global__ void __launch_bounds__(block_size)
mha_fill_kernel(scalar_t *out_tensor, const int32_t *const start_row, const size_t num_rows) {
size_t row_stride = gridDim.y * blockDim.x;
size_t row_index = blockIdx.x + static_cast<size_t>(start_row[0]);
size_t col_index = blockIdx.y * blockDim.x + threadIdx.x;
while (row_index < num_rows) {
out_tensor[row_index * row_stride + col_index] = 0;
row_index += gridDim.x;
}
}
// fast zero-fills of tensors // fast zero-fills of tensors
void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &start_index) { void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &start_index) {
std::vector<size_t> shape = transformer_engine::pytorch::convertShape(self.shape()); std::vector<size_t> shape = transformer_engine::pytorch::convertShape(self.shape());
...@@ -48,33 +31,23 @@ void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &s ...@@ -48,33 +31,23 @@ void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &s
for (int i = 1; i <= shape.size(); i++) { for (int i = 1; i <= shape.size(); i++) {
fcd_size *= shape[i]; fcd_size *= shape[i];
} }
TORCH_CHECK(fcd_size % block_size == 0, "input size not aligned to block size");
const int num_mp = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; NVTE_CHECK(fcd_size % block_size == 0, "input size not aligned to block size");
uint64_t num_blk_y = (uint64_t)(fcd_size / block_size);
uint64_t num_blk_x = (uint64_t)((num_mp * ctas_per_sm + num_blk_y - 1) / num_blk_y); size_t element_size = transformer_engine::pytorch::typeToSize(self.dtype());
dim3 dim_grid(num_blk_x, num_blk_y); int32_t start_row = start_index.data_ptr<int32_t>()[0];
dim3 dim_block(block_size); void *base_ptr = static_cast<char *>(self.get_rowwise_data().data_ptr) +
// need to convert DType to scalar_type somehow static_cast<size_t>(start_row) * fcd_size * element_size;
at::ScalarType scalar_type = transformer_engine::pytorch::GetATenDType(self.dtype()); size_t num_rows_to_zero = max_tokens - start_row;
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( size_t total_bytes = num_rows_to_zero * fcd_size * element_size;
at::ScalarType::Half, at::ScalarType::BFloat16, scalar_type, "mha_fill", [&]() {
mha_fill_kernel<<<dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>( nvte_memset(base_ptr, 0, total_bytes, at::cuda::getCurrentCUDAStream());
static_cast<scalar_t *>(self.get_rowwise_data().data_ptr),
static_cast<int32_t *>(start_index.data_ptr()), max_tokens);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
} }
// extract seed and offset from PhiloxCudaState void unpack(at::PhiloxCudaState arg, int64_t *rng_state_ptr) {
__global__ void unpack(at::PhiloxCudaState arg, int64_t *rng_state_ptr) { nvte_extract_seed_and_offset(rng_state_ptr, arg.captured_, arg.seed_.ptr, arg.seed_.val,
if (arg.captured_) { arg.offset_.ptr, arg.offset_.val, arg.offset_intragraph_,
rng_state_ptr[0] = static_cast<int64_t>(*arg.seed_.ptr); at::cuda::getCurrentCUDAStream());
rng_state_ptr[1] =
static_cast<int64_t>(*(arg.offset_.ptr) + static_cast<int64_t>(arg.offset_intragraph_));
} else {
rng_state_ptr[0] = static_cast<int64_t>(arg.seed_.val);
rng_state_ptr[1] = static_cast<int64_t>(arg.offset_.val);
}
} }
// extract PhiloxCudaState from CUDA random number generator // extract PhiloxCudaState from CUDA random number generator
...@@ -193,8 +166,7 @@ std::vector<py::object> fused_attn_fwd( ...@@ -193,8 +166,7 @@ std::vector<py::object> fused_attn_fwd(
rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); rng_gen, at::cuda::detail::getDefaultCUDAGenerator());
at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread);
auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
unpack<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>( unpack(philox_args, static_cast<int64_t *>(rng_state.data_ptr()));
philox_args, static_cast<int64_t *>(rng_state.data_ptr()));
auto te_rng_state = makeTransformerEngineTensor(rng_state); auto te_rng_state = makeTransformerEngineTensor(rng_state);
// create auxiliary output tensors // create auxiliary output tensors
...@@ -512,72 +484,13 @@ std::vector<py::object> fused_attn_bwd( ...@@ -512,72 +484,13 @@ std::vector<py::object> fused_attn_bwd(
return {py_dQ, py_dK, py_dV, py::cast(dBias)}; return {py_dQ, py_dK, py_dV, py::cast(dBias)};
} }
namespace flash_attention {
constexpr int warp_size = 32;
constexpr int type_size = 2; // FP16 or BF16
constexpr int nvec = sizeof(uint64_t) / type_size;
constexpr int load_size = warp_size * nvec;
constexpr int block_size = 512;
template <typename T>
__launch_bounds__(block_size) __global__
void prepare_kernel_fwd(const T *qkvi, T *qkv, const size_t B, const size_t S, const size_t Z,
const size_t W) {
const int warpid = (blockDim.x * blockIdx.x + threadIdx.x) / warp_size;
const int id_in_warp = threadIdx.x % warp_size;
const size_t offset_input = blockIdx.y * W + warpid * 3 * W * Z + id_in_warp * nvec;
const T *my_input = qkvi + offset_input;
const size_t s = warpid / B;
if (s >= S) return;
const size_t b = warpid % B;
const size_t offset_output = blockIdx.y * B * S * Z * W + (s + b * S) * W * Z + id_in_warp * nvec;
T *my_output = qkv + offset_output;
for (int i = 0; i < Z; ++i) {
uint64_t *out = reinterpret_cast<uint64_t *>(my_output + i * load_size);
*out = *reinterpret_cast<const uint64_t *>(my_input + i * load_size * 3);
}
}
template <typename T>
__launch_bounds__(block_size) __global__
void prepare_kernel_bwd(const T *q, const T *k, const T *v, T *qkv, const size_t B,
const size_t S, const size_t Z, const size_t W) {
const T *input = blockIdx.y == 0 ? q : (blockIdx.y == 1 ? k : v);
const int warpid = (blockDim.x * blockIdx.x + threadIdx.x) / warp_size;
const int id_in_warp = threadIdx.x % warp_size;
const size_t offset_input = warpid * W * Z + id_in_warp * nvec;
const T *my_input = input + offset_input;
const size_t b = warpid / S;
if (b >= B) return;
const size_t s = warpid % S;
const size_t offset_output = (b + s * B) * 3 * W * Z + id_in_warp * nvec + blockIdx.y * W;
T *my_output = qkv + offset_output;
for (int i = 0; i < Z; ++i) {
uint64_t *out = reinterpret_cast<uint64_t *>(my_output + i * load_size * 3);
*out = *reinterpret_cast<const uint64_t *>(my_input + i * load_size);
}
}
} // namespace flash_attention
at::Tensor fa_prepare_fwd(at::Tensor qkvi) { at::Tensor fa_prepare_fwd(at::Tensor qkvi) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
NVTE_CHECK(qkvi.dim() == 4, "Expected 4-dim tensor."); NVTE_CHECK(qkvi.dim() == 4, "Expected 4-dim tensor.");
NVTE_CHECK(qkvi.scalar_type() == at::ScalarType::Half || NVTE_CHECK(qkvi.scalar_type() == at::ScalarType::Half ||
qkvi.scalar_type() == at::ScalarType::BFloat16); qkvi.scalar_type() == at::ScalarType::BFloat16);
NVTE_CHECK(qkvi.size(3) % flash_attention::load_size == 0);
NVTE_CHECK(qkvi.size(3) == flash_attention::load_size);
NVTE_CHECK(qkvi.stride(3) == 1, "Wrong stride."); NVTE_CHECK(qkvi.stride(3) == 1, "Wrong stride.");
NVTE_CHECK(qkvi.stride(2) == 3 * qkvi.size(3), "Wrong stride."); NVTE_CHECK(qkvi.stride(2) == 3 * qkvi.size(3), "Wrong stride.");
NVTE_CHECK(qkvi.stride(1) == 3 * qkvi.size(3) * qkvi.size(2), "Wrong stride."); NVTE_CHECK(qkvi.stride(1) == 3 * qkvi.size(3) * qkvi.size(2), "Wrong stride.");
...@@ -587,27 +500,18 @@ at::Tensor fa_prepare_fwd(at::Tensor qkvi) { ...@@ -587,27 +500,18 @@ at::Tensor fa_prepare_fwd(at::Tensor qkvi) {
std::vector<int64_t> shape = {3, qkvi.size(1), qkvi.size(0), qkvi.size(2), qkvi.size(3)}; std::vector<int64_t> shape = {3, qkvi.size(1), qkvi.size(0), qkvi.size(2), qkvi.size(3)};
at::Tensor qkv = at::empty(shape, at::CUDA(qkvi.scalar_type())); at::Tensor qkv = at::empty(shape, at::CUDA(qkvi.scalar_type()));
size_t warps = qkvi.size(0) * qkvi.size(1); auto te_qkvi = makeTransformerEngineTensor(qkvi);
size_t warps_per_block = flash_attention::block_size / flash_attention::warp_size; auto te_qkv = makeTransformerEngineTensor(qkv);
size_t blocks = (warps + warps_per_block - 1) / warps_per_block;
dim3 grid(blocks, 3); nvte_prepare_flash_attn_fwd(te_qkvi.data(), te_qkv.data(), at::cuda::getCurrentCUDAStream());
int threads = flash_attention::block_size;
if (qkvi.scalar_type() == at::ScalarType::Half) {
using dtype = at::Half;
flash_attention::prepare_kernel_fwd<dtype>
<<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
qkvi.data_ptr<dtype>(), qkv.data_ptr<dtype>(), shape[1], shape[2], shape[3], shape[4]);
} else {
using dtype = at::BFloat16;
flash_attention::prepare_kernel_fwd<dtype>
<<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
qkvi.data_ptr<dtype>(), qkv.data_ptr<dtype>(), shape[1], shape[2], shape[3], shape[4]);
}
return qkv; return qkv;
} }
at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v) { at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
NVTE_CHECK(q.is_contiguous()); NVTE_CHECK(q.is_contiguous());
NVTE_CHECK(k.is_contiguous()); NVTE_CHECK(k.is_contiguous());
NVTE_CHECK(v.is_contiguous()); NVTE_CHECK(v.is_contiguous());
...@@ -618,36 +522,18 @@ at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v) { ...@@ -618,36 +522,18 @@ at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v) {
q.scalar_type() == at::ScalarType::BFloat16); q.scalar_type() == at::ScalarType::BFloat16);
NVTE_CHECK(k.scalar_type() == q.scalar_type()); NVTE_CHECK(k.scalar_type() == q.scalar_type());
NVTE_CHECK(v.scalar_type() == q.scalar_type()); NVTE_CHECK(v.scalar_type() == q.scalar_type());
NVTE_CHECK(q.size(3) % flash_attention::load_size == 0);
NVTE_CHECK(q.size(3) == flash_attention::load_size);
NVTE_CHECK(k.size(3) % flash_attention::load_size == 0);
NVTE_CHECK(k.size(3) == flash_attention::load_size);
NVTE_CHECK(v.size(3) % flash_attention::load_size == 0);
NVTE_CHECK(v.size(3) == flash_attention::load_size);
// 3 x [s, b, n, h] -> [b, s, n, 3 * h] // 3 x [s, b, n, h] -> [b, s, n, 3 * h]
std::vector<int64_t> shape = {q.size(1), q.size(0), q.size(2), 3 * q.size(3)}; std::vector<int64_t> shape = {q.size(1), q.size(0), q.size(2), 3 * q.size(3)};
at::Tensor qkv = at::empty(shape, at::CUDA(q.scalar_type())); at::Tensor qkv = at::empty(shape, at::CUDA(q.scalar_type()));
size_t warps = q.size(0) * q.size(1); auto te_q = makeTransformerEngineTensor(q);
size_t warps_per_block = flash_attention::block_size / flash_attention::warp_size; auto te_k = makeTransformerEngineTensor(k);
size_t blocks = (warps + warps_per_block - 1) / warps_per_block; auto te_v = makeTransformerEngineTensor(v);
dim3 grid(blocks, 3); auto te_qkv = makeTransformerEngineTensor(qkv);
int threads = flash_attention::block_size;
if (q.scalar_type() == at::ScalarType::Half) { nvte_prepare_flash_attn_bwd(te_q.data(), te_k.data(), te_v.data(), te_qkv.data(),
using dtype = at::Half; at::cuda::getCurrentCUDAStream());
flash_attention::prepare_kernel_bwd<dtype>
<<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
q.data_ptr<dtype>(), k.data_ptr<dtype>(), v.data_ptr<dtype>(), qkv.data_ptr<dtype>(),
q.size(0), q.size(1), q.size(2), q.size(3));
} else {
using dtype = at::BFloat16;
flash_attention::prepare_kernel_bwd<dtype>
<<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
q.data_ptr<dtype>(), k.data_ptr<dtype>(), v.data_ptr<dtype>(), qkv.data_ptr<dtype>(),
q.size(0), q.size(1), q.size(2), q.size(3));
}
return qkv; return qkv;
} }
...@@ -658,6 +544,9 @@ at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v) { ...@@ -658,6 +544,9 @@ at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v) {
at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_seqlens, at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_seqlens,
int half_idx) { int half_idx) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
NVTE_CHECK(tensor.dim() == 3 || tensor.dim() == 4); NVTE_CHECK(tensor.dim() == 3 || tensor.dim() == 4);
NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);
NVTE_CHECK(cu_seqlens.dim() == 1); NVTE_CHECK(cu_seqlens.dim() == 1);
...@@ -683,18 +572,12 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_s ...@@ -683,18 +572,12 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_s
shape[seq_dim] /= 2; shape[seq_dim] /= 2;
at::Tensor half = at::empty(shape, at::CUDA(tensor.scalar_type())); at::Tensor half = at::empty(shape, at::CUDA(tensor.scalar_type()));
// Launch Kernel auto te_tensor = makeTransformerEngineTensor(tensor);
constexpr unsigned int block = 256; auto te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens);
unsigned int grid_x = (tensor.size(seq_dim) / 2 * 32 + block - 1) / block; auto te_half = makeTransformerEngineTensor(half);
unsigned int grid_y = 1;
for (int i = 0; i < seq_dim; i++) { nvte_cp_thd_read_half_tensor(te_tensor.data(), te_cu_seqlens.data(), te_half.data(), half_idx,
grid_y *= tensor.size(i); at::cuda::getCurrentCUDAStream());
}
dim3 grid = {grid_x, grid_y};
transformer_engine::fused_attn::thd_read_half_tensor_kernel<<<
grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
half.data_ptr(), tensor.data_ptr(), cu_seqlens.data_ptr<int>(), batch, hidden_size_in_bytes,
half_idx, tensor.size(seq_dim));
return half; return half;
} }
...@@ -705,6 +588,9 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_s ...@@ -705,6 +588,9 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_s
void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_step, void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_step,
const at::Tensor &cu_seqlens, bool lse_packed) { const at::Tensor &cu_seqlens, bool lse_packed) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float); NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float);
NVTE_CHECK(lse_per_step.scalar_type() == at::ScalarType::Float); NVTE_CHECK(lse_per_step.scalar_type() == at::ScalarType::Float);
NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);
...@@ -738,26 +624,20 @@ void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_st ...@@ -738,26 +624,20 @@ void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_st
NVTE_CHECK(cu_seqlens.size(0) == batch + 1); NVTE_CHECK(cu_seqlens.size(0) == batch + 1);
} }
constexpr unsigned int block = 256; auto te_lse = makeTransformerEngineTensor(lse);
unsigned int grid_x = (lse_seqlen / 2 + block - 1) / block; auto te_lse_per_step = makeTransformerEngineTensor(lse_per_step);
unsigned int grid_y = num_heads; auto te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens);
dim3 grid = {grid_x, grid_y};
if (lse_packed) { nvte_cp_thd_second_half_lse_correction(te_lse.data(), te_lse_per_step.data(),
transformer_engine::fused_attn::thd_lse_kernel<true, LseCorrectionFunctor> te_cu_seqlens.data(), lse_packed,
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>( at::cuda::getCurrentCUDAStream());
lse.data_ptr<float>(), lse_per_step.data_ptr<float>(), cu_seqlens.data_ptr<int>(),
batch, num_heads, lse_seqlen, second_half_lse_seqlen);
} else {
transformer_engine::fused_attn::thd_lse_kernel<false, LseCorrectionFunctor>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
lse.data_ptr<float>(), lse_per_step.data_ptr<float>(), cu_seqlens.data_ptr<int>(),
batch, num_heads, lse_seqlen, second_half_lse_seqlen);
}
} }
at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_seqlens, at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_seqlens,
bool lse_packed, int second_half_lse_seqlen) { bool lse_packed, int second_half_lse_seqlen) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float); NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float);
NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);
NVTE_CHECK(cu_seqlens.dim() == 1); NVTE_CHECK(cu_seqlens.dim() == 1);
...@@ -790,22 +670,13 @@ at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_ ...@@ -790,22 +670,13 @@ at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_
at::Tensor half_lse = at::zeros(shape, at::CUDA(lse.scalar_type())); at::Tensor half_lse = at::zeros(shape, at::CUDA(lse.scalar_type()));
constexpr unsigned int block = 256; auto te_lse = makeTransformerEngineTensor(lse);
unsigned int grid_x = (lse_seqlen / 2 + block - 1) / block; auto te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens);
unsigned int grid_y = num_heads; auto te_half_lse = makeTransformerEngineTensor(half_lse);
dim3 grid = {grid_x, grid_y};
if (lse_packed) { nvte_cp_thd_read_second_half_lse(te_lse.data(), te_cu_seqlens.data(), te_half_lse.data(),
transformer_engine::fused_attn::thd_lse_kernel<true, ReadLseFunctor> lse_packed, second_half_lse_seqlen,
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>( at::cuda::getCurrentCUDAStream());
lse.data_ptr<float>(), half_lse.data_ptr<float>(), cu_seqlens.data_ptr<int>(), batch,
num_heads, lse_seqlen, second_half_lse_seqlen);
} else {
transformer_engine::fused_attn::thd_lse_kernel<false, ReadLseFunctor>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
lse.data_ptr<float>(), half_lse.data_ptr<float>(), cu_seqlens.data_ptr<int>(), batch,
num_heads, lse_seqlen, second_half_lse_seqlen);
}
return half_lse; return half_lse;
} }
...@@ -814,194 +685,38 @@ at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_ ...@@ -814,194 +685,38 @@ at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_
* Support THD format for Context Parallel: Out correction in forward * Support THD format for Context Parallel: Out correction in forward
**************************************************************************************************/ **************************************************************************************************/
template <typename dtype, int only_second_half>
static void thd_out_correction_helper(at::Tensor out, const at::Tensor &out_per_step,
const at::Tensor &lse, const at::Tensor &lse_per_step,
const at::Tensor &cu_seqlens, bool lse_packed) {
NVTE_CHECK(out.scalar_type() == out_per_step.scalar_type());
NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float);
NVTE_CHECK(lse_per_step.scalar_type() == at::ScalarType::Float);
NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);
int total_tokens = out.size(0);
int num_heads = out.size(1);
int dim_per_head = out.size(2);
NVTE_CHECK(out_per_step.size(0) == total_tokens / (only_second_half + 1));
NVTE_CHECK(out_per_step.size(1) == num_heads);
NVTE_CHECK(out_per_step.size(2) == dim_per_head);
int batch, lse_seqlen, lse_per_step_seqlen;
if (lse_packed) {
batch = cu_seqlens.size(0) - 1;
lse_seqlen = lse.size(1);
lse_per_step_seqlen = lse_per_step.size(1);
NVTE_CHECK(lse.size(0) == num_heads);
NVTE_CHECK(lse_seqlen >= total_tokens);
NVTE_CHECK(lse_per_step.size(0) == num_heads);
NVTE_CHECK(lse_per_step_seqlen >= lse_seqlen / (only_second_half + 1));
} else {
batch = lse.size(0);
lse_seqlen = lse.size(2);
lse_per_step_seqlen = lse_per_step.size(2);
NVTE_CHECK(lse.size(1) == num_heads);
NVTE_CHECK(lse_per_step.size(0) == batch);
NVTE_CHECK(lse_per_step.size(1) == num_heads);
NVTE_CHECK(lse_per_step_seqlen == lse_seqlen / (only_second_half + 1));
NVTE_CHECK(cu_seqlens.size(0) == batch + 1);
}
constexpr int tile = 16;
constexpr int block = 512;
unsigned int grid_x =
(static_cast<size_t>(total_tokens) / (only_second_half + 1) * tile + block - 1) / block;
dim3 grid = {grid_x, (unsigned int)num_heads};
if (lse_packed) {
transformer_engine::fused_attn::thd_out_correction_kernel<dtype, only_second_half, tile, true>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
out.data_ptr<dtype>(), out_per_step.data_ptr<dtype>(), lse.data_ptr<float>(),
lse_per_step.data_ptr<float>(), cu_seqlens.data_ptr<int>(), batch, num_heads,
dim_per_head, lse_seqlen, lse_per_step_seqlen);
} else {
transformer_engine::fused_attn::thd_out_correction_kernel<dtype, only_second_half, tile, false>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
out.data_ptr<dtype>(), out_per_step.data_ptr<dtype>(), lse.data_ptr<float>(),
lse_per_step.data_ptr<float>(), cu_seqlens.data_ptr<int>(), batch, num_heads,
dim_per_head, lse_seqlen, lse_per_step_seqlen);
}
}
void thd_out_correction(at::Tensor out, const at::Tensor &out_per_step, const at::Tensor &lse, void thd_out_correction(at::Tensor out, const at::Tensor &out_per_step, const at::Tensor &lse,
const at::Tensor &lse_per_step, const at::Tensor &cu_seqlens, const at::Tensor &lse_per_step, const at::Tensor &cu_seqlens,
bool only_second_half, bool lse_packed) { bool only_second_half, bool lse_packed) {
if (only_second_half) { using namespace transformer_engine;
if (out.scalar_type() == at::ScalarType::Half) { using namespace transformer_engine::pytorch;
using dtype = at::Half;
thd_out_correction_helper<dtype, 1>(out, out_per_step, lse, lse_per_step, cu_seqlens, auto te_out = makeTransformerEngineTensor(out);
lse_packed); auto te_out_per_step = makeTransformerEngineTensor(out_per_step);
} else if (out.scalar_type() == at::ScalarType::BFloat16) { auto te_lse = makeTransformerEngineTensor(lse);
using dtype = at::BFloat16; auto te_lse_per_step = makeTransformerEngineTensor(lse_per_step);
thd_out_correction_helper<dtype, 1>(out, out_per_step, lse, lse_per_step, cu_seqlens, auto te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens);
lse_packed); nvte_cp_thd_out_correction(te_out.data(), te_out_per_step.data(), te_lse.data(),
} else if (out.scalar_type() == at::ScalarType::Float) { te_lse_per_step.data(), te_cu_seqlens.data(), only_second_half,
using dtype = float; lse_packed, at::cuda::getCurrentCUDAStream());
thd_out_correction_helper<dtype, 1>(out, out_per_step, lse, lse_per_step, cu_seqlens,
lse_packed);
} else {
NVTE_ERROR("Unsupported dtype of out\n");
}
} else {
if (out.scalar_type() == at::ScalarType::Half) {
using dtype = at::Half;
thd_out_correction_helper<dtype, 0>(out, out_per_step, lse, lse_per_step, cu_seqlens,
lse_packed);
} else if (out.scalar_type() == at::ScalarType::BFloat16) {
using dtype = at::BFloat16;
thd_out_correction_helper<dtype, 0>(out, out_per_step, lse, lse_per_step, cu_seqlens,
lse_packed);
} else if (out.scalar_type() == at::ScalarType::Float) {
using dtype = float;
thd_out_correction_helper<dtype, 0>(out, out_per_step, lse, lse_per_step, cu_seqlens,
lse_packed);
} else {
NVTE_ERROR("Unsupported dtype of out\n");
}
}
} }
/*************************************************************************************************** /***************************************************************************************************
* Support THD format for Context Parallel: Gradients correction in backward * Support THD format for Context Parallel: Gradients correction in backward
**************************************************************************************************/ **************************************************************************************************/
template <typename dtype, typename Functor_0, typename Functor_1, int functor_idx>
static void thd_grad_correction_helper(at::Tensor grad, const at::Tensor &grad_per_step,
const at::Tensor &cu_seqlens) {
NVTE_CHECK(grad.dim() == 3 || grad.dim() == 4);
NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);
NVTE_CHECK(cu_seqlens.dim() == 1);
// Shape of dq is [t, h, d], so the dimension of "t" is 0
// Shape of dkv is [2, t, h, d], so the dimension of "t" is 1
int seq_dim = grad.dim() == 3 ? 0 : 1;
int total_tokens = grad.size(seq_dim);
int num_heads = grad.size(seq_dim + 1);
int dim_per_head = grad.size(seq_dim + 2);
int batch = cu_seqlens.size(0) - 1;
if constexpr (functor_idx < 2) {
NVTE_CHECK(grad_per_step.size(seq_dim) == total_tokens / 2);
} else {
NVTE_CHECK(grad_per_step.size(seq_dim) == total_tokens);
}
NVTE_CHECK(grad_per_step.size(seq_dim + 1) == num_heads);
NVTE_CHECK(grad_per_step.size(seq_dim + 2) == dim_per_head);
size_t hidden_size = num_heads * dim_per_head;
NVTE_CHECK((hidden_size * c10::elementSize(grad.scalar_type())) % 16 == 0);
constexpr unsigned int block = 256;
unsigned int grid_x;
if constexpr (functor_idx < 2) {
grid_x = (total_tokens / 2 * 32 + block - 1) / block;
} else {
grid_x = (total_tokens * 32 + block - 1) / block;
}
unsigned int grid_y = 1;
for (int i = 0; i < seq_dim; i++) {
grid_y *= grad.size(i);
}
dim3 grid = {grid_x, grid_y};
transformer_engine::fused_attn::thd_grad_correction_kernel<dtype, Functor_0, Functor_1,
functor_idx, 32>
<<<grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>(
grad.data_ptr<dtype>(), grad_per_step.data_ptr<dtype>(), cu_seqlens.data_ptr<int>(),
batch, hidden_size, total_tokens);
}
template <typename dtype>
static void thd_grad_dispatcher(at::Tensor grad, const at::Tensor &grad_per_step,
const at::Tensor &cu_seqlens, const std::string &first_half,
const std::string &second_half) {
if (first_half == "add" && second_half == "none") {
thd_grad_correction_helper<dtype, AddFunctor<dtype>, EmptyFunctor, 0>(grad, grad_per_step,
cu_seqlens);
} else if (first_half == "copy" && second_half == "none") {
thd_grad_correction_helper<dtype, CopyFunctor, EmptyFunctor, 0>(grad, grad_per_step,
cu_seqlens);
} else if (first_half == "none" && second_half == "add") {
thd_grad_correction_helper<dtype, EmptyFunctor, AddFunctor<dtype>, 1>(grad, grad_per_step,
cu_seqlens);
} else if (first_half == "none" && second_half == "copy") {
thd_grad_correction_helper<dtype, EmptyFunctor, CopyFunctor, 1>(grad, grad_per_step,
cu_seqlens);
} else if (first_half == "add" && second_half == "copy") {
thd_grad_correction_helper<dtype, AddFunctor<dtype>, CopyFunctor, 2>(grad, grad_per_step,
cu_seqlens);
} else if (first_half == "copy" && second_half == "add") {
thd_grad_correction_helper<dtype, CopyFunctor, AddFunctor<dtype>, 2>(grad, grad_per_step,
cu_seqlens);
} else {
NVTE_ERROR("Unsupported Functor of first half and second_half\n");
}
}
void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step, void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step,
const at::Tensor &cu_seqlens, const std::string &first_half, const at::Tensor &cu_seqlens, const std::string &first_half,
const std::string &second_half) { const std::string &second_half) {
if (grad.scalar_type() == at::ScalarType::Half) { using namespace transformer_engine;
thd_grad_dispatcher<at::Half>(grad, grad_per_step, cu_seqlens, first_half, second_half); using namespace transformer_engine::pytorch;
} else if (grad.scalar_type() == at::ScalarType::BFloat16) {
thd_grad_dispatcher<at::BFloat16>(grad, grad_per_step, cu_seqlens, first_half, second_half); auto te_grad = makeTransformerEngineTensor(grad);
} else if (grad.scalar_type() == at::ScalarType::Float) { auto te_grad_per_step = makeTransformerEngineTensor(grad_per_step);
thd_grad_dispatcher<float>(grad, grad_per_step, cu_seqlens, first_half, second_half); auto te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens);
} else { nvte_cp_thd_grad_correction(te_grad.data(), te_grad_per_step.data(), te_cu_seqlens.data(),
NVTE_ERROR("Unsupported dtype of grad\n"); first_half.data(), second_half.data(),
} at::cuda::getCurrentCUDAStream());
} }
/*************************************************************************************************** /***************************************************************************************************
...@@ -1010,6 +725,9 @@ void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step, ...@@ -1010,6 +725,9 @@ void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step,
at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_tokens, at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_tokens,
int world_size, int rank) { int world_size, int rank) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int);
NVTE_CHECK(cu_seqlens.dim() == 1); NVTE_CHECK(cu_seqlens.dim() == 1);
NVTE_CHECK(cu_seqlens.size(0) >= 2); NVTE_CHECK(cu_seqlens.size(0) >= 2);
...@@ -1022,11 +740,11 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t ...@@ -1022,11 +740,11 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t
std::vector<int64_t> shape = {total_tokens / world_size}; std::vector<int64_t> shape = {total_tokens / world_size};
at::Tensor output = at::empty(shape, at::CUDA(at::ScalarType::Int)); at::Tensor output = at::empty(shape, at::CUDA(at::ScalarType::Int));
constexpr unsigned int block = 256; auto te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens);
unsigned int grid = (output.size(0) + block - 1) / block; auto te_output = makeTransformerEngineTensor(output);
transformer_engine::fused_attn::thd_partition_indices_kernel<<<
grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>( nvte_cp_thd_get_partitioned_indices(te_cu_seqlens.data(), te_output.data(), total_tokens,
output.data_ptr<int>(), cu_seqlens.data_ptr<int>(), batch, total_tokens, world_size, rank); world_size, rank, at::cuda::getCurrentCUDAStream());
return output; return output;
} }
...@@ -1035,39 +753,22 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t ...@@ -1035,39 +753,22 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t
* KV Cache: Convert a tensor from qkv_format = thd to qkv_format = bshd * KV Cache: Convert a tensor from qkv_format = thd to qkv_format = bshd
**************************************************************************************************/ **************************************************************************************************/
template <typename scalar_t>
void convert_thd_to_bshd_launcher(at::Tensor tensor, at::Tensor new_tensor, at::Tensor cu_seqlens,
int b, int max_seq_len, int h, int d) {
transformer_engine::fused_attn::
convert_thd_to_bshd_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
reinterpret_cast<scalar_t *>(tensor.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(new_tensor.data_ptr<scalar_t>()), cu_seqlens.data_ptr<int>(),
b, max_seq_len, h, d);
}
at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len) { at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
int h = tensor.size(1); int h = tensor.size(1);
int d = tensor.size(2); int d = tensor.size(2);
std::vector<int64_t> shape = {b, max_seq_len, h, d}; std::vector<int64_t> shape = {b, max_seq_len, h, d};
at::Tensor new_tensor = at::zeros(shape, at::CUDA(tensor.scalar_type())); at::Tensor new_tensor = at::zeros(shape, at::CUDA(tensor.scalar_type()));
if (new_tensor.scalar_type() == at::ScalarType::Half) {
using dtype = at::Half; auto te_tensor = makeTransformerEngineTensor(tensor);
convert_thd_to_bshd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d); auto te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens);
} else if (new_tensor.scalar_type() == at::ScalarType::BFloat16) { auto te_new_tensor = makeTransformerEngineTensor(new_tensor);
using dtype = at::BFloat16;
convert_thd_to_bshd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d); nvte_convert_thd_to_bshd(te_tensor.data(), te_cu_seqlens.data(), te_new_tensor.data(), b,
} else if (new_tensor.scalar_type() == at::ScalarType::Float) { max_seq_len, at::cuda::getCurrentCUDAStream());
using dtype = float;
convert_thd_to_bshd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
} else if (new_tensor.scalar_type() == at::ScalarType::Float8_e4m3fn) {
using dtype = at::Float8_e4m3fn;
convert_thd_to_bshd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
} else if (new_tensor.scalar_type() == at::ScalarType::Float8_e5m2) {
using dtype = at::Float8_e5m2;
convert_thd_to_bshd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
} else {
NVTE_ERROR("Unsupported dtype for KV cache.\n");
}
return new_tensor; return new_tensor;
} }
...@@ -1075,95 +776,33 @@ at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, ...@@ -1075,95 +776,33 @@ at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b,
* KV Cache: Convert a tensor from qkv_format = bshd to qkv_format = thd * KV Cache: Convert a tensor from qkv_format = bshd to qkv_format = thd
**************************************************************************************************/ **************************************************************************************************/
template <typename scalar_t>
void convert_bshd_to_thd_launcher(at::Tensor tensor, at::Tensor new_tensor, at::Tensor cu_seqlens,
int b, int max_seq_len, int h, int d) {
transformer_engine::fused_attn::
convert_bshd_to_thd_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
reinterpret_cast<scalar_t *>(tensor.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(new_tensor.data_ptr<scalar_t>()), cu_seqlens.data_ptr<int>(),
b, max_seq_len, h, d);
}
at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t) { at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t) {
int b = tensor.size(0); using namespace transformer_engine;
using namespace transformer_engine::pytorch;
int max_seq_len = tensor.size(1); int max_seq_len = tensor.size(1);
int h = tensor.size(2); int h = tensor.size(2);
int d = tensor.size(3); int d = tensor.size(3);
std::vector<int64_t> shape = {t, h, d}; std::vector<int64_t> shape = {t, h, d};
at::Tensor new_tensor = at::zeros(shape, at::CUDA(tensor.scalar_type())); at::Tensor new_tensor = at::zeros(shape, at::CUDA(tensor.scalar_type()));
if (tensor.scalar_type() == at::ScalarType::Half) {
using dtype = at::Half;
convert_bshd_to_thd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
} else if (tensor.scalar_type() == at::ScalarType::BFloat16) {
using dtype = at::BFloat16;
convert_bshd_to_thd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
} else if (tensor.scalar_type() == at::ScalarType::Float) {
using dtype = float;
convert_bshd_to_thd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
} else if (tensor.scalar_type() == at::ScalarType::Float8_e4m3fn) {
using dtype = at::Float8_e4m3fn;
convert_bshd_to_thd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
} else if (tensor.scalar_type() == at::ScalarType::Float8_e5m2) {
using dtype = at::Float8_e5m2;
convert_bshd_to_thd_launcher<dtype>(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d);
} else {
NVTE_ERROR("Unsupported dtype for KV cache.\n");
}
return new_tensor;
}
/*************************************************************************************************** auto te_tensor = makeTransformerEngineTensor(tensor);
* KV Cache: Copy new KV tokens to the KV cache auto te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens);
* 1. new_k and new_v are in qkv_format; k_cache and v_cache are in 'bshd' format auto te_new_tensor = makeTransformerEngineTensor(new_tensor);
* 2. cu_new_lens and cu_cached_lens are in shape [b + 1]; cu_cached_lens include the added lens
* in current step
* 3. Non-paged KV cache is a special case of paged KV cache, with page_table = [b, 1] and
* max_pages_per_seq = 1. We use the same underlying kernel for both non-paged and paged.
* Set is_non_paged = True/False to indicate as such.
* 4. is_non_paged = True also re-indexes the KV cache, e.g. the initial batch indices [0, 3, 1, 2]
* becomes [0, 1, 1, 2]. The page_table = batch_indices.unsqueeze(1) is however unchanged.
* batch_indices_post can be used for monotonical indexing, i.e. [0, 1, 2, 3]. batch_indices is
* preserved for the next layer in the same iteration.
* 5. Only supports same page_table for k_cache and v_cache
* 6. Only pad_between_seqs = False when qkv_format = thd, i.e. there should be no pad tokens
* between sequences in new_k and new_v such as [a a a 0..0 b b 0..0 c 0..0].
**************************************************************************************************/
template <typename scalar_t> nvte_convert_bshd_to_thd(te_tensor.data(), te_cu_seqlens.data(), te_new_tensor.data(), t,
void copy_to_kv_cache_launcher(at::Tensor new_k, at::Tensor new_v, at::Tensor k_cache, at::cuda::getCurrentCUDAStream());
at::Tensor v_cache, at::Tensor page_table, at::Tensor cu_new_lens,
at::Tensor cu_cached_lens, NVTE_QKV_Format qkv_format, int h_kv, return new_tensor;
int d_k, int d_v, int b, int max_ctx_len, int max_seq_len,
int max_pages_per_seq, bool is_non_paged) {
if (new_k.data_ptr() != nullptr && new_v.data_ptr() != nullptr && k_cache.data_ptr() != nullptr &&
v_cache.data_ptr() != nullptr) {
if (is_non_paged) {
transformer_engine::fused_attn::
reindex_kv_cache_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
reinterpret_cast<scalar_t *>(k_cache.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(v_cache.data_ptr<scalar_t>()),
page_table.data_ptr<int>(), cu_new_lens.data_ptr<int>(),
cu_cached_lens.data_ptr<int>(), h_kv, d_k, d_v, b, max_seq_len);
}
transformer_engine::fused_attn::
copy_to_kv_cache_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
reinterpret_cast<scalar_t *>(new_k.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(new_v.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(k_cache.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(v_cache.data_ptr<scalar_t>()), page_table.data_ptr<int>(),
cu_new_lens.data_ptr<int>(), cu_cached_lens.data_ptr<int>(), qkv_format, h_kv, d_k, d_v,
b, max_ctx_len, max_seq_len, max_pages_per_seq, is_non_paged);
}
} }
void copy_to_kv_cache(at::Tensor new_k, at::Tensor new_v, at::Tensor k_cache, at::Tensor v_cache, void copy_to_kv_cache(at::Tensor new_k, at::Tensor new_v, at::Tensor k_cache, at::Tensor v_cache,
at::Tensor page_table, at::Tensor cu_new_lens, at::Tensor cu_cached_lens, at::Tensor page_table, at::Tensor cu_new_lens, at::Tensor cu_cached_lens,
NVTE_QKV_Format qkv_format, int b, int max_ctx_len, int max_seq_len, NVTE_QKV_Format qkv_format, int b, int max_ctx_len, int max_seq_len,
int max_pages_per_seq, bool is_non_paged) { int max_pages_per_seq, bool is_non_paged) {
int h_kv = new_k.size(-2); using namespace transformer_engine;
int d_k = new_k.size(-1); using namespace transformer_engine::pytorch;
int d_v = new_v.size(-1);
NVTE_CHECK(k_cache.scalar_type() == v_cache.scalar_type() && NVTE_CHECK(k_cache.scalar_type() == v_cache.scalar_type() &&
new_k.scalar_type() == new_v.scalar_type() && new_k.scalar_type() == new_v.scalar_type() &&
new_k.scalar_type() == k_cache.scalar_type(), new_k.scalar_type() == k_cache.scalar_type(),
...@@ -1171,33 +810,17 @@ void copy_to_kv_cache(at::Tensor new_k, at::Tensor new_v, at::Tensor k_cache, at ...@@ -1171,33 +810,17 @@ void copy_to_kv_cache(at::Tensor new_k, at::Tensor new_v, at::Tensor k_cache, at
NVTE_CHECK(qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD || NVTE_CHECK(qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD ||
qkv_format == NVTE_QKV_Format::NVTE_THD, qkv_format == NVTE_QKV_Format::NVTE_THD,
"qkv_format must be {BSHD, SBHD, THD}."); "qkv_format must be {BSHD, SBHD, THD}.");
if (k_cache.scalar_type() == at::ScalarType::Half) {
using dtype = at::Half; auto te_new_k = makeTransformerEngineTensor(new_k);
copy_to_kv_cache_launcher<dtype>(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens, auto te_new_v = makeTransformerEngineTensor(new_v);
cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, auto te_k_cache = makeTransformerEngineTensor(k_cache);
max_seq_len, max_pages_per_seq, is_non_paged); auto te_v_cache = makeTransformerEngineTensor(v_cache);
auto te_page_table = makeTransformerEngineTensor(page_table);
} else if (k_cache.scalar_type() == at::ScalarType::BFloat16) { auto te_cu_new_lens = makeTransformerEngineTensor(cu_new_lens);
using dtype = at::BFloat16; auto te_cu_cached_lens = makeTransformerEngineTensor(cu_cached_lens);
copy_to_kv_cache_launcher<dtype>(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens,
cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, nvte_copy_to_kv_cache(te_new_k.data(), te_new_v.data(), te_k_cache.data(), te_v_cache.data(),
max_seq_len, max_pages_per_seq, is_non_paged); te_page_table.data(), te_cu_new_lens.data(), te_cu_cached_lens.data(),
} else if (k_cache.scalar_type() == at::ScalarType::Float) { qkv_format, b, max_ctx_len, max_seq_len, max_pages_per_seq, is_non_paged,
using dtype = float; at::cuda::getCurrentCUDAStream());
copy_to_kv_cache_launcher<dtype>(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens,
cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len,
max_seq_len, max_pages_per_seq, is_non_paged);
} else if (k_cache.scalar_type() == at::ScalarType::Float8_e4m3fn) {
using dtype = at::Float8_e4m3fn;
copy_to_kv_cache_launcher<dtype>(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens,
cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len,
max_seq_len, max_pages_per_seq, is_non_paged);
} else if (k_cache.scalar_type() == at::ScalarType::Float8_e5m2) {
using dtype = at::Float8_e5m2;
copy_to_kv_cache_launcher<dtype>(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens,
cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len,
max_seq_len, max_pages_per_seq, is_non_paged);
} else {
NVTE_ERROR("Unsupported dtype for KV cache.\n");
}
} }
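
To make points 3-4 of the KV-cache comment above concrete, here is a minimal sketch of the non-paged calling pattern (helper name and tensor shapes are illustrative, not part of this diff): the page table degenerates to the batch indices, with one page per sequence.

// Illustrative sketch of the non-paged path: page_table = batch_indices.unsqueeze(1),
// max_pages_per_seq = 1, is_non_paged = true.
void copy_step_non_paged(at::Tensor new_k, at::Tensor new_v, at::Tensor k_cache,
                         at::Tensor v_cache, at::Tensor cu_new_lens, at::Tensor cu_cached_lens,
                         int b, int max_ctx_len, int max_seq_len) {
  at::Tensor batch_indices =
      at::arange(b, at::TensorOptions().dtype(at::kInt).device(at::kCUDA));
  at::Tensor page_table = batch_indices.unsqueeze(1);  // shape [b, 1]
  copy_to_kv_cache(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens, cu_cached_lens,
                   NVTE_QKV_Format::NVTE_BSHD, b, max_ctx_len, max_seq_len,
                   /*max_pages_per_seq=*/1, /*is_non_paged=*/true);
}
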