Unverified Commit 7b18f235 authored by Michael Goldfarb, committed by GitHub

Fused Attention Support 64-bit Ragged Offsets for Large THD Tensors (#1230)



* Use 64-bit offsets for cuDNN 9.5+
* Align workspace tensors to 16B.
* Fix a bug where std::accumulate overflowed on large tensor shapes.
* Only support 64-bit offsets on the arbitrary-sequence-length FP16 backend.
Signed-off-by: Michael Goldfarb <mgoldfarb@nvidia.com>
parent 29e3a090
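For context on the change above: ragged (THD) offsets index elements within a packed tensor, so the largest offset grows with total tokens × heads × head_dim and can exceed the int32 range on large inputs. A minimal sketch of that threshold, with assumed (hypothetical) shapes rather than values taken from this commit:

```cpp
#include <cstdint>
#include <cstdio>
#include <limits>

int main() {
  // Hypothetical packed THD shape: 2^20 total tokens, 16 heads, head_dim 128.
  const int64_t total_tokens = int64_t{1} << 20;
  const int64_t num_heads = 16;
  const int64_t head_dim = 128;

  // The last ragged offset points one element past the packed tensor,
  // i.e. total_tokens * num_heads * head_dim elements (larger still for
  // interleaved 3HD layouts, which pack Q, K and V together).
  const int64_t max_offset = total_tokens * num_heads * head_dim;  // 2^31

  std::printf("max offset = %lld, fits in int32: %s\n",
              static_cast<long long>(max_offset),
              max_offset <= std::numeric_limits<int32_t>::max() ? "yes" : "no");  // prints "no"
  return 0;
}
```

The get_ragged_offset_dtype() helper introduced further down selects kInt64 in exactly this situation, provided cuDNN 9.5+ is available.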
......@@ -56,6 +56,7 @@ constexpr T DIVUP(const T &x, const T &y) {
using byte = uint8_t;
using int32 = int32_t;
using int64 = int64_t;
using fp32 = float;
using fp16 = half;
using bf16 = nv_bfloat16;
......@@ -73,6 +74,7 @@ constexpr inline const char *type_name() noexcept;
}
TRANSFORMER_ENGINE_TYPE_NAME(uint8_t)
TRANSFORMER_ENGINE_TYPE_NAME(int32_t)
TRANSFORMER_ENGINE_TYPE_NAME(int64_t)
TRANSFORMER_ENGINE_TYPE_NAME(float)
TRANSFORMER_ENGINE_TYPE_NAME(half)
TRANSFORMER_ENGINE_TYPE_NAME(nv_bfloat16)
......@@ -84,7 +86,7 @@ TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e5m2)
template <typename T>
struct TypeInfo {
using types = std::tuple<byte, int32, fp32, fp16, bf16, fp8e4m3, fp8e5m2>;
using types = std::tuple<byte, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2>;
template <typename U, DType current>
struct Helper {
......@@ -121,7 +123,11 @@ struct TypeInfo {
{ __VA_ARGS__ } \
} break; \
case DType::kInt32: { \
using type = float; \
using type = int32_t; \
{ __VA_ARGS__ } \
} break; \
case DType::kInt64: { \
using type = int64_t; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat32: { \
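The hunk above extends the runtime DType dispatch with a kInt64 case (and corrects the kInt32 case to map to int32_t). As a self-contained illustration of the pattern, a stand-in dispatch macro, not the repository's actual one, might look like:

```cpp
#include <cstdint>
#include <cstdio>

enum class DType { kInt32, kInt64, kFloat32 };

// Stand-in for the repository's type-dispatch macro: it maps a runtime DType
// onto a compile-time C++ type alias and then runs the body with that alias.
#define TYPE_SWITCH(dtype, type, ...)            \
  switch (dtype) {                               \
    case DType::kInt32: {                        \
      using type = int32_t;                      \
      { __VA_ARGS__ }                            \
    } break;                                     \
    case DType::kInt64: {                        \
      using type = int64_t;                      \
      { __VA_ARGS__ }                            \
    } break;                                     \
    case DType::kFloat32: {                      \
      using type = float;                        \
      { __VA_ARGS__ }                            \
    } break;                                     \
  }

int main() {
  DType dtype = DType::kInt64;
  TYPE_SWITCH(dtype, T, std::printf("element size = %zu bytes\n", sizeof(T));)  // 8 bytes
  return 0;
}
```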
......@@ -246,6 +252,14 @@ inline int log2_ceil(int value) {
return log2_value;
}
template <size_t B>
inline size_t alignTo(size_t x) {
size_t r = x % B;
if (r == 0) return x;
return x + B - r;
}
template <typename T>
struct is_fp8 : std::false_type {};
......
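The new alignTo<B> helper rounds a byte count up to the next multiple of B; later hunks use it to 16-byte-align each workspace sub-allocation. A small usage sketch (the helper is reproduced here only so the snippet stands alone):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Same rounding logic as the alignTo<B> helper added above, reproduced for
// illustration: round x up to the next multiple of B.
template <size_t B>
inline size_t alignTo(size_t x) {
  size_t r = x % B;
  return r == 0 ? x : x + B - r;
}

int main() {
  assert(alignTo<16>(0) == 0);
  assert(alignTo<16>(1) == 16);
  assert(alignTo<16>(16) == 16);
  // (b + 1) int64 ragged offsets for b = 5 -> 48 bytes, already 16B-aligned.
  assert(alignTo<16>(6 * sizeof(int64_t)) == 48);
  return 0;
}
```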
......@@ -80,7 +80,18 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
const int sm_arch_ = cuda::sm_arch(device_id);
NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type.");
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
auto cudnn_runtime_version = cudnnGetVersion();
// For ragged offsets, only 32-bit values are supported prior to cuDNN 9.5;
// they are only used when the THD format is requested.
const bool requires_64bit_ragged_offset =
(qkv_format == NVTE_THD && fused_attn::get_ragged_offset_dtype(
layout_group, num_attn_heads, num_gqa_groups, max_seqlen_q,
max_seqlen_kv, head_dim_qk, head_dim_v) == DType::kInt64);
const bool supported_ragged_offset_size =
(!requires_64bit_ragged_offset || cudnn_runtime_version >= 90500);
if (((q_dtype == NVTEDType::kNVTEFloat8E4M3) || (q_dtype == NVTEDType::kNVTEFloat8E5M2)) &&
(sm_arch_ >= 90) && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) &&
(((cudnn_runtime_version >= 8900) && (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) &&
......@@ -91,7 +102,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
((qkv_format == NVTE_QKV_Format::NVTE_BSHD) ||
(qkv_format == NVTE_QKV_Format::NVTE_SBHD)) &&
((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) ||
(attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK))))) {
(attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)))) &&
!requires_64bit_ragged_offset) {
if (cudnn_runtime_version >= 8900) {
backend = NVTE_Fused_Attn_Backend::NVTE_FP8;
} else {
......@@ -118,7 +130,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
(qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) ||
(qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) ||
(qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)) &&
((window_size_left == -1) && (window_size_right == -1 || window_size_right == 0))) {
((window_size_left == -1) && (window_size_right == -1 || window_size_right == 0)) &&
!requires_64bit_ragged_offset) {
flag_m512 = true;
}
if ( // architecture
......@@ -183,7 +196,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
max_seqlen_q == max_seqlen_kv)) &&
dropout == 0.0 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS &&
(qkv_format == NVTE_QKV_Format::NVTE_BSHD ||
qkv_format == NVTE_QKV_Format::NVTE_SBHD)))))) {
qkv_format == NVTE_QKV_Format::NVTE_SBHD))))) &&
// check 64-bit ragged offset support
(supported_ragged_offset_size)) {
flag_arb = true;
}
if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) {
......
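Summarizing the gating added above, as a hypothetical distillation rather than the library's API: the FP8 and max-512-seqlen backends are never eligible when 64-bit ragged offsets are required, and the arbitrary-seqlen backend then additionally requires cuDNN 9.5+.

```cpp
#include <cstddef>

// Hypothetical distillation of the new checks in nvte_get_fused_attn_backend.
// Names and parameters are placeholders; the real decision also depends on
// layout group, head counts, head dims and sequence lengths.
bool requires_64bit_ragged_offset(bool is_thd, bool offsets_exceed_int32) {
  // 64-bit ragged offsets are only ever needed for the THD format.
  return is_thd && offsets_exceed_int32;
}

bool fp8_or_max512_backend_eligible(bool needs_64bit) {
  return !needs_64bit;  // these backends only accept 32-bit ragged offsets
}

bool arbitrary_seqlen_backend_eligible(bool needs_64bit, size_t cudnn_version) {
  // cuDNN 9.5+ is required once offsets no longer fit in int32.
  return !needs_64bit || cudnn_version >= 90500;
}
```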
......@@ -58,6 +58,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size,
cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI);
bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) ||
......@@ -75,7 +76,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
if (is_ragged) {
NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!");
}
auto cudnn_runtime_version = cudnnGetVersion();
const auto cudnn_runtime_version = cudnnGetVersion();
const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? DType::kInt64 : DType::kInt32;
try {
FADescriptor_v1 descriptor{b,
......@@ -145,22 +147,22 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
.set_name("offset_q")
.set_dim({b + 1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
.set_data_type(get_cudnn_fe_dtype(ragged_offset_type)));
offset_k = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_k")
.set_dim({b + 1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
.set_data_type(get_cudnn_fe_dtype(ragged_offset_type)));
offset_v = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_v")
.set_dim({b + 1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
.set_data_type(get_cudnn_fe_dtype(ragged_offset_type)));
offset_o = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_o")
.set_dim({b + 1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
.set_data_type(get_cudnn_fe_dtype(ragged_offset_type)));
std::vector<int64_t> q_stride(4);
std::vector<int64_t> k_stride(4);
......@@ -311,10 +313,15 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
offset_v, offset_o, dropout_seed, dropout_offset] =
get_graph(sdpa_f16_fprop_cache, descriptor);
auto plan_workspace_size = mha_graph->get_workspace_size();
// Exit to request upper level API to allocate memory if needed
size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t);
size_t seqlen_offsets_workspace_size = 4 * (b + 1) * sizeof(int32_t);
// N.B. Care should be taken to align each of the added workspace tensors to their type.
// We do this by adding padding at the end of each separate allocation.
auto plan_workspace_size = alignTo<16>(mha_graph->get_workspace_size());
const size_t num_bytes_per_seqlen = alignTo<16>(b * sizeof(int32_t));
const size_t actual_seqlen_workspace_size = 2 * num_bytes_per_seqlen;
const size_t num_bytes_per_ragged_offset =
alignTo<16>((b + 1) * typeToSize(ragged_offset_type));
const size_t seqlen_offsets_workspace_size = 4 * num_bytes_per_ragged_offset;
if (workspace == nullptr) {
*workspace_size =
plan_workspace_size + actual_seqlen_workspace_size + seqlen_offsets_workspace_size;
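The workspace is a single allocation carved into a cuDNN plan region, two actual-seqlen regions and four ragged-offset regions; padding each piece to 16 bytes keeps every region's base pointer 16B-aligned. A hedged sketch of the resulting layout math, mirroring the code above with illustrative names:

```cpp
#include <cstddef>
#include <cstdint>

// Illustrative layout of the fused-attention workspace after this change.
// Each region is padded to 16 bytes so the next region starts aligned.
struct WorkspaceLayout {
  size_t plan_bytes;     // cuDNN execution-plan workspace, 16B-padded
  size_t seqlen_bytes;   // 2 regions of b int32 actual sequence lengths
  size_t offsets_bytes;  // 4 regions of (b + 1) ragged offsets (int32/int64)
  size_t total() const { return plan_bytes + seqlen_bytes + offsets_bytes; }
};

inline size_t align16(size_t x) { return (x + 15) / 16 * 16; }

WorkspaceLayout make_layout(size_t plan_size, size_t b, size_t offset_elem_size) {
  WorkspaceLayout l;
  l.plan_bytes = align16(plan_size);
  l.seqlen_bytes = 2 * align16(b * sizeof(int32_t));
  l.offsets_bytes = 4 * align16((b + 1) * offset_elem_size);  // Q, K, V, O
  return l;
}
```

The backward path below computes its workspace with the same scheme.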
......@@ -339,7 +346,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
constexpr size_t nthreads_per_block = 128;
const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block;
void *devActualSeqlenQ = static_cast<int8_t *>(workspace) + plan_workspace_size;
void *devActualSeqlenKV = static_cast<int8_t *>(devActualSeqlenQ) + b * sizeof(int32_t);
void *devActualSeqlenKV = static_cast<int8_t *>(devActualSeqlenQ) + num_bytes_per_seqlen;
cu_seqlens_to_actual_seqlens<<<grid, nthreads_per_block, 0, stream>>>(
b, static_cast<const int32_t *>(devPtrCuSeqlensQ),
static_cast<const int32_t *>(devPtrCuSeqlensKV), static_cast<int32_t *>(devActualSeqlenQ),
......@@ -353,15 +360,14 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
const size_t grid = (b + nthreads_per_block) / nthreads_per_block;
void *devOffsetsQ =
static_cast<int8_t *>(workspace) + plan_workspace_size + actual_seqlen_workspace_size;
void *devOffsetsK = static_cast<int8_t *>(devOffsetsQ) + (b + 1) * sizeof(int32_t);
void *devOffsetsV = static_cast<int8_t *>(devOffsetsK) + (b + 1) * sizeof(int32_t);
void *devOffsetsO = static_cast<int8_t *>(devOffsetsV) + (b + 1) * sizeof(int32_t);
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout);
void *devOffsetsK = static_cast<int8_t *>(devOffsetsQ) + num_bytes_per_ragged_offset;
void *devOffsetsV = static_cast<int8_t *>(devOffsetsK) + num_bytes_per_ragged_offset;
void *devOffsetsO = static_cast<int8_t *>(devOffsetsV) + num_bytes_per_ragged_offset;
const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout);
cu_seqlens_padded_to_offsets<<<grid, nthreads_per_block, 0, stream>>>(
layout_group, b, h, hg, d_qk, d_v, static_cast<int32_t *>(devPtrSeqOffsetsQ),
static_cast<int32_t *>(devPtrSeqOffsetsKV), static_cast<int32_t *>(devOffsetsQ),
static_cast<int32_t *>(devOffsetsK), static_cast<int32_t *>(devOffsetsV),
static_cast<int32_t *>(devOffsetsO));
static_cast<int32_t *>(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK,
devOffsetsV, devOffsetsO);
variant_pack[offset_q] = devOffsetsQ;
variant_pack[offset_k] = devOffsetsK;
variant_pack[offset_v] = devOffsetsV;
......@@ -390,6 +396,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size,
cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI);
bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) ||
......@@ -404,10 +411,14 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
(mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK));
bool is_dropout = (dropout_probability != 0.0f);
bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD);
auto cudnn_runtime_version = cudnnGetVersion();
const auto cudnn_runtime_version = cudnnGetVersion();
const int device_id = cuda::current_device();
const int sm_arch_ = cuda::sm_arch(device_id);
// We choose between 32-bit and 64-bit offsets depending on need.
// This allows us to support older cuDNN runtimes gracefully.
const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? DType::kInt64 : DType::kInt32;
try {
FADescriptor_v1 descriptor{b,
h,
......@@ -481,22 +492,22 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
.set_name("offset_q")
.set_dim({b + 1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
.set_data_type(get_cudnn_fe_dtype(ragged_offset_type)));
offset_k = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_k")
.set_dim({b + 1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
.set_data_type(get_cudnn_fe_dtype(ragged_offset_type)));
offset_v = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_v")
.set_dim({b + 1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
.set_data_type(get_cudnn_fe_dtype(ragged_offset_type)));
offset_o = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("offset_o")
.set_dim({b + 1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
.set_data_type(get_cudnn_fe_dtype(ragged_offset_type)));
std::vector<int64_t> q_stride(4);
std::vector<int64_t> k_stride(4);
std::vector<int64_t> v_stride(4);
......@@ -693,11 +704,15 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
offset_q, offset_k, offset_v, offset_o, dropout_seed, dropout_offset] =
get_graph(sdpa_f16_bprop_cache, descriptor);
auto plan_workspace_size = mha_graph->get_workspace_size();
// Exit to request upper level API to allocate memory if needed
size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t);
size_t seqlen_offsets_workspace_size = 4 * (b + 1) * sizeof(int32_t);
// N.B. Care should be taken to align each of the added workspace tensors to their type.
// We do this by adding padding at the end of each separate allocation.
auto plan_workspace_size = alignTo<16>(mha_graph->get_workspace_size());
const size_t num_bytes_per_seqlen = alignTo<16>(b * sizeof(int32_t));
const size_t actual_seqlen_workspace_size = 2 * num_bytes_per_seqlen;
const size_t num_bytes_per_ragged_offset =
alignTo<16>((b + 1) * typeToSize(ragged_offset_type));
const size_t seqlen_offsets_workspace_size = 4 * num_bytes_per_ragged_offset;
if (workspace == nullptr) {
*workspace_size =
plan_workspace_size + actual_seqlen_workspace_size + seqlen_offsets_workspace_size;
......@@ -735,7 +750,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
constexpr size_t nthreads_per_block = 128;
const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block;
void *devActualSeqlenQ = static_cast<int8_t *>(workspace) + plan_workspace_size;
void *devActualSeqlenKV = static_cast<int8_t *>(devActualSeqlenQ) + b * sizeof(int32_t);
void *devActualSeqlenKV = static_cast<int8_t *>(devActualSeqlenQ) + num_bytes_per_seqlen;
cu_seqlens_to_actual_seqlens<<<grid, nthreads_per_block, 0, stream>>>(
b, static_cast<const int32_t *>(devPtrCuSeqlensQ),
static_cast<const int32_t *>(devPtrCuSeqlensKV), static_cast<int32_t *>(devActualSeqlenQ),
......@@ -749,15 +764,14 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
const size_t grid = (b + nthreads_per_block) / nthreads_per_block;
void *devOffsetsQ =
static_cast<int8_t *>(workspace) + plan_workspace_size + actual_seqlen_workspace_size;
void *devOffsetsK = static_cast<int8_t *>(devOffsetsQ) + (b + 1) * sizeof(int32_t);
void *devOffsetsV = static_cast<int8_t *>(devOffsetsK) + (b + 1) * sizeof(int32_t);
void *devOffsetsO = static_cast<int8_t *>(devOffsetsV) + (b + 1) * sizeof(int32_t);
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout);
void *devOffsetsK = static_cast<int8_t *>(devOffsetsQ) + num_bytes_per_ragged_offset;
void *devOffsetsV = static_cast<int8_t *>(devOffsetsK) + num_bytes_per_ragged_offset;
void *devOffsetsO = static_cast<int8_t *>(devOffsetsV) + num_bytes_per_ragged_offset;
const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout);
cu_seqlens_padded_to_offsets<<<grid, nthreads_per_block, 0, stream>>>(
layout_group, b, h, hg, d_qk, d_v, static_cast<int32_t *>(devPtrSeqOffsetsQ),
static_cast<int32_t *>(devPtrSeqOffsetsKV), static_cast<int32_t *>(devOffsetsQ),
static_cast<int32_t *>(devOffsetsK), static_cast<int32_t *>(devOffsetsV),
static_cast<int32_t *>(devOffsetsO));
static_cast<int32_t *>(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK,
devOffsetsV, devOffsetsO);
variant_pack[offset_q] = devOffsetsQ;
variant_pack[offset_k] = devOffsetsK;
variant_pack[offset_v] = devOffsetsV;
......
......@@ -4,6 +4,8 @@
* See LICENSE for license information.
************************************************************************/
#include <algorithm>
#include "../common.h"
#include "transformer_engine/fused_attn.h"
#include "utils.h"
......@@ -337,7 +339,7 @@ cudnn_frontend::Operation ternary_pw_op_create(cudnn_frontend::Tensor const &xDe
}
// convert cu_seqlens_q to qkv/o_ragged_offset and actual_seqlens_q
__global__ void cu_seqlens_to_offsets(size_t b, size_t h, size_t d, int32_t *cu_seqlens_q,
__global__ void cu_seqlens_to_offsets(int64_t b, int64_t h, int64_t d, int32_t *cu_seqlens_q,
int32_t *actual_seqlens_q, int32_t *qkv_ragged_offset,
int32_t *o_ragged_offset) {
size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
......@@ -362,12 +364,13 @@ __global__ void cu_seqlens_to_actual_seqlens(size_t b, int32_t const *const q_cu
}
// convert cu_seqlens_padded to offsets
__global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, size_t b, size_t h,
size_t hg, size_t d_qk, size_t d_v,
int32_t *cu_seqlens_q_padded,
int32_t *cu_seqlens_kv_padded, int32_t *offsets_q,
int32_t *offsets_k, int32_t *offsets_v,
int32_t *offsets_o) {
template <class OFFSETS_T>
__device__ void cu_seqlens_padded_to_offsets_impl(NVTE_QKV_Layout_Group layout_group, int64_t b,
int64_t h, int64_t hg, int64_t d_qk, int64_t d_v,
const int32_t *cu_seqlens_q_padded,
const int32_t *cu_seqlens_kv_padded,
OFFSETS_T *offsets_q, OFFSETS_T *offsets_k,
OFFSETS_T *offsets_v, OFFSETS_T *offsets_o) {
size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < b + 1) {
offsets_o[tid] = h * d_v * cu_seqlens_q_padded[tid];
......@@ -393,6 +396,60 @@ __global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group,
}
}
__global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, int64_t b,
int64_t h, int64_t hg, int64_t d_qk, int64_t d_v,
const int32_t *cu_seqlens_q_padded,
const int32_t *cu_seqlens_kv_padded,
DType offset_dtype, void *offsets_q, void *offsets_k,
void *offsets_v, void *offsets_o) {
if (offset_dtype == DType::kInt32) {
cu_seqlens_padded_to_offsets_impl<int32_t>(
layout_group, b, h, hg, d_qk, d_v, cu_seqlens_q_padded, cu_seqlens_kv_padded,
reinterpret_cast<int32_t *>(offsets_q), reinterpret_cast<int32_t *>(offsets_k),
reinterpret_cast<int32_t *>(offsets_v), reinterpret_cast<int32_t *>(offsets_o));
} else {
assert(offset_dtype == DType::kInt64 && "expect int64");
cu_seqlens_padded_to_offsets_impl<int64_t>(
layout_group, b, h, hg, d_qk, d_v, cu_seqlens_q_padded, cu_seqlens_kv_padded,
reinterpret_cast<int64_t *>(offsets_q), reinterpret_cast<int64_t *>(offsets_k),
reinterpret_cast<int64_t *>(offsets_v), reinterpret_cast<int64_t *>(offsets_o));
}
}
DType get_ragged_offset_dtype(NVTE_QKV_Layout_Group layout_group, int64_t num_attn_heads,
int64_t num_gqa_groups, int64_t max_seqlen_q, int64_t max_seqlen_kv,
int64_t head_dim_qk, int64_t head_dim_v) {
std::array<int64_t, 4> offsets_qkvo{};
switch (layout_group) {
case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD:
offsets_qkvo[0] = num_attn_heads * head_dim_qk * max_seqlen_q;
offsets_qkvo[1] = num_gqa_groups * head_dim_qk * max_seqlen_kv;
offsets_qkvo[2] = num_gqa_groups * head_dim_v * max_seqlen_kv;
break;
case NVTE_QKV_Layout_Group::NVTE_3HD:
case NVTE_QKV_Layout_Group::NVTE_H3D:
offsets_qkvo[0] = 3 * num_attn_heads * head_dim_qk * max_seqlen_q;
offsets_qkvo[1] = offsets_qkvo[0];
offsets_qkvo[2] = offsets_qkvo[0];
break;
case NVTE_QKV_Layout_Group::NVTE_HD_2HD:
case NVTE_QKV_Layout_Group::NVTE_HD_H2D:
offsets_qkvo[0] = num_attn_heads * head_dim_qk * max_seqlen_q;
offsets_qkvo[1] = 2 * num_gqa_groups * head_dim_qk * max_seqlen_kv;
offsets_qkvo[2] = offsets_qkvo[1];
break;
}
offsets_qkvo[3] = num_attn_heads * head_dim_qk * max_seqlen_q;
size_t max_offset = *std::max_element(offsets_qkvo.begin(), offsets_qkvo.end());
if (max_offset > std::numeric_limits<int32_t>::max()) {
return DType::kInt64;
}
return DType::kInt32;
}
} // namespace fused_attn
// get cuDNN data type
......
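A worked example of the get_ragged_offset_dtype() threshold, using assumed shapes rather than values from the commit: for an NVTE_3HD layout the QKV offset scales as 3 * num_attn_heads * head_dim_qk * max_seqlen_q, so moderately long packed sequences already require 64-bit offsets.

```cpp
#include <cstdint>
#include <limits>

// Worked example of the threshold in get_ragged_offset_dtype (shapes assumed).
int main() {
  const int64_t num_attn_heads = 16, head_dim_qk = 128, max_seqlen_q = int64_t{1} << 19;
  const int64_t max_offset = 3 * num_attn_heads * head_dim_qk * max_seqlen_q;  // 3,221,225,472
  // 3,221,225,472 > INT32_MAX (2,147,483,647), so kInt64 offsets would be selected.
  return max_offset > std::numeric_limits<int32_t>::max() ? 0 : 1;
}
```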
......@@ -118,7 +118,7 @@ struct FADescriptor_v1 {
}
};
__global__ void cu_seqlens_to_offsets(size_t b, size_t h, size_t d, int32_t *cu_seqlens_q,
__global__ void cu_seqlens_to_offsets(int64_t b, int64_t h, int64_t d, int32_t *cu_seqlens_q,
int32_t *actual_seqlens_q, int32_t *qkv_ragged_offset,
int32_t *o_ragged_offset);
......@@ -126,12 +126,17 @@ __global__ void cu_seqlens_to_actual_seqlens(size_t b, int32_t const *const q_cu
int32_t const *const kv_cu_seqlens, int32_t *q_seqlens,
int32_t *kv_seqlens);
__global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, size_t b, size_t h,
size_t hg, size_t d_qk, size_t d_v,
int32_t *cu_seqlens_q_padded,
int32_t *cu_seqlens_kv_padded, int32_t *offsets_q,
int32_t *offsets_k, int32_t *offsets_v,
int32_t *offsets_o);
__global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, int64_t b,
int64_t h, int64_t hg, int64_t d_qk, int64_t d_v,
const int32_t *cu_seqlens_q_padded,
const int32_t *cu_seqlens_kv_padded,
DType offset_dtype, void *offsets_q, void *offsets_k,
void *offsets_v, void *offsets_o);
DType get_ragged_offset_dtype(NVTE_QKV_Layout_Group layout_group, int64_t num_attn_heads,
int64_t num_gqa_groups, int64_t max_seqlen_q, int64_t max_seqlen_kv,
int64_t head_dim_qk, int64_t head_dim_v);
} // namespace fused_attn
cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t);
......
......@@ -24,7 +24,7 @@ extern "C" {
enum NVTEDType {
kNVTEByte = 0, /*!< Byte */
kNVTEInt32 = 1, /*!< 32-bit integer */
kNVTEInt64 = 2, /*!< 32-bit integer */
kNVTEInt64 = 2, /*!< 64-bit integer */
kNVTEFloat32 = 3, /*!< 32-bit float */
kNVTEFloat16 = 4, /*!< 16-bit float (E5M10) */
kNVTEBFloat16 = 5, /*!< 16-bit bfloat (E8M7) */
......
......@@ -540,9 +540,7 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto dqkv = buffers[12];
auto dqkv_tensor = TensorWrapper(dqkv, qkv_shape, dtype);
if (is_ragged) {
size_t dqkv_size =
std::accumulate(qkv_shape.cbegin(), qkv_shape.cend(), 1, std::multiplies<size_t>());
cudaMemsetAsync(dqkv, 0, dqkv_size * typeToSize(dtype), stream);
cudaMemsetAsync(dqkv, 0, product(qkv_shape) * typeToSize(dtype), stream);
}
nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
......@@ -564,12 +562,8 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto dkv = buffers[13];
auto dkv_tensor = TensorWrapper(dkv, kv_shape, dtype);
if (is_ragged) {
size_t dq_size =
std::accumulate(q_shape.cbegin(), q_shape.cend(), 1, std::multiplies<size_t>());
size_t dkv_size =
std::accumulate(kv_shape.cbegin(), kv_shape.cend(), 1, std::multiplies<size_t>());
cudaMemsetAsync(dq, 0, dq_size * typeToSize(dtype), stream);
cudaMemsetAsync(dkv, 0, dkv_size * typeToSize(dtype), stream);
cudaMemsetAsync(dq, 0, product(q_shape) * typeToSize(dtype), stream);
cudaMemsetAsync(dkv, 0, product(kv_shape) * typeToSize(dtype), stream);
}
nvte_fused_attn_bwd_kvpacked(
q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
......@@ -597,14 +591,9 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto dv = buffers[14];
auto dv_tensor = TensorWrapper(dv, v_shape, dtype);
if (is_ragged) {
size_t dq_size =
std::accumulate(q_shape.cbegin(), q_shape.cend(), 1, std::multiplies<size_t>());
size_t dk_size =
std::accumulate(k_shape.cbegin(), k_shape.cend(), 1, std::multiplies<size_t>());
size_t dv_size = dk_size;
cudaMemsetAsync(dq, 0, dq_size * typeToSize(dtype), stream);
cudaMemsetAsync(dk, 0, dk_size * typeToSize(dtype), stream);
cudaMemsetAsync(dv, 0, dv_size * typeToSize(dtype), stream);
cudaMemsetAsync(dq, 0, product(q_shape) * typeToSize(dtype), stream);
cudaMemsetAsync(dk, 0, product(k_shape) * typeToSize(dtype), stream);
cudaMemsetAsync(dv, 0, product(v_shape) * typeToSize(dtype), stream);
}
nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(),
......
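The overflow fixed here comes from seeding std::accumulate with the int literal 1: the fold is then carried out in int, and the size_t product for a large THD tensor is truncated when converted back, yielding a silently wrong element count. A minimal reproduction of the pitfall and the 64-bit-safe fold (product() in the diff above is the repo's own helper; the shape below is assumed):

```cpp
#include <cstdint>
#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

int main() {
  // e.g. a large packed dQKV shape: 1M tokens x 3 x 16 heads x 128 head_dim
  std::vector<size_t> shape = {1'000'000, 3, 16, 128};

  // Pitfall: the int literal 1 makes std::accumulate fold in int, so the
  // ~6.1-billion-element product is truncated on conversion back to int.
  int bad = std::accumulate(shape.cbegin(), shape.cend(), 1, std::multiplies<size_t>());

  // Fix: seed with a size_t so the whole fold is carried out in 64 bits.
  size_t good =
      std::accumulate(shape.cbegin(), shape.cend(), size_t{1}, std::multiplies<size_t>());

  std::printf("int accumulator: %d, size_t accumulator: %zu\n", bad, good);
  return 0;
}
```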