Commit 2b05e121 authored by yuguo


Merge commit 'a69692ac' of https://github.com/NVIDIA/TransformerEngine
parents 0fd441c2 a69692ac
@@ -325,7 +325,7 @@ void thd_read_half_tensor(const Tensor &tensor, const Tensor &cu_seqlens, Tensor
   int batch = cu_seqlens_shape[0] - 1;
   int num_heads = tensor_shape[seq_dim + 1];
   int dim_per_head = tensor_shape[seq_dim + 2];
-  int hidden_size_in_bytes = num_heads * dim_per_head * typeToSize(tensor.dtype());
+  int hidden_size_in_bytes = (num_heads * dim_per_head * typeToNumBits(tensor.dtype())) / 8;
   // For 128-bits load/store
   NVTE_CHECK(hidden_size_in_bytes % 16 == 0);
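Note on this hunk (and the matching ones below): `typeToSize` returned a per-element size in whole bytes, which cannot represent sub-byte dtypes; `typeToNumBits` keeps the arithmetic in bits and converts to bytes only after the full product. A minimal sketch of the idea — the `DType` and `typeToNumBits` definitions here are illustrative stand-ins, not the library's:

```cpp
// Illustrative sketch only: why the size math moved from bytes to bits.
#include <cassert>
#include <cstddef>

enum class DType { kFloat16, kFloat8, kFloat4 };  // kFloat4: a hypothetical 4-bit type

constexpr size_t typeToNumBits(DType t) {
  switch (t) {
    case DType::kFloat16: return 16;
    case DType::kFloat8: return 8;
    case DType::kFloat4: return 4;
  }
  return 0;
}

int main() {
  const size_t num_heads = 16, dim_per_head = 64;
  // Multiply bit counts first, convert to bytes last: exact for sub-byte
  // dtypes as long as the total is byte-aligned.
  const size_t hidden_size_in_bytes =
      (num_heads * dim_per_head * typeToNumBits(DType::kFloat4)) / 8;
  assert(hidden_size_in_bytes == 512);
  // A per-element *byte* size would have to round 4 bits down to 0 bytes,
  // collapsing the whole product; counting bits avoids that.
  return 0;
}
```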
@@ -582,7 +582,7 @@ static void thd_grad_correction_helper(Tensor grad, const Tensor &grad_per_step,
   NVTE_CHECK(grad_per_step_shape[seq_dim + 2] == dim_per_head);
   size_t hidden_size = num_heads * dim_per_head;
-  NVTE_CHECK((hidden_size * typeToSize(grad.dtype())) % 16 == 0);
+  NVTE_CHECK(((hidden_size * typeToNumBits(grad.dtype())) / 8) % 16 == 0);
   constexpr unsigned int block = 256;
   unsigned int grid_x;
@@ -677,9 +677,9 @@ void nvte_cp_thd_read_half_tensor(const NVTETensor &tensor, const NVTETensor &cu
   NVTE_API_CALL(nvte_thd_read_half_tensor);
   using namespace transformer_engine;
-  context_parallel::thd_read_half_tensor(*reinterpret_cast<Tensor *>(tensor),
-                                         *reinterpret_cast<Tensor *>(cu_seqlens),
-                                         *reinterpret_cast<Tensor *>(half), half_idx, stream);
+  context_parallel::thd_read_half_tensor(*convertNVTETensorCheck(tensor),
+                                         *convertNVTETensorCheck(cu_seqlens),
+                                         *convertNVTETensorCheck(half), half_idx, stream);
 }

 void nvte_cp_thd_second_half_lse_correction(NVTETensor lse, const NVTETensor &lse_per_step,
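The other recurring change in this commit replaces raw `reinterpret_cast<Tensor *>` on the opaque `NVTETensor` handles with `convertNVTETensorCheck`, so an invalid handle fails loudly at the API boundary instead of propagating a bad pointer into a kernel. Only the workspace tensor goes through the unchecked `convertNVTETensor`, since it may legitimately be passed in a not-yet-allocated state during workspace-size queries. A minimal sketch of the pattern, assuming helpers of this shape (the real ones live in Transformer Engine's common headers and validate more than a null check):

```cpp
// Sketch of the checked-conversion pattern; bodies are illustrative.
#include <stdexcept>

struct Tensor { /* shape, dtype, dptr, ... */ };
typedef void *NVTETensor;  // opaque handle exposed by the public C API

inline Tensor *convertNVTETensor(NVTETensor t) {
  // Unchecked: a null handle stays null; the caller must cope (used for the
  // workspace, which may be queried before it is allocated).
  return reinterpret_cast<Tensor *>(t);
}

inline Tensor *convertNVTETensorCheck(NVTETensor t) {
  // Checked: dereferencing the result is always safe.
  Tensor *ptr = reinterpret_cast<Tensor *>(t);
  if (ptr == nullptr) throw std::runtime_error("invalid NVTETensor handle");
  return ptr;
}
```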
@@ -689,8 +689,8 @@ void nvte_cp_thd_second_half_lse_correction(NVTETensor lse, const NVTETensor &ls
   using namespace transformer_engine;
   context_parallel::thd_second_half_lse_correction(
-      *reinterpret_cast<Tensor *>(lse), *reinterpret_cast<Tensor *>(lse_per_step),
-      *reinterpret_cast<Tensor *>(cu_seqlens), lse_packed, stream);
+      *convertNVTETensorCheck(lse), *convertNVTETensorCheck(lse_per_step),
+      *convertNVTETensorCheck(cu_seqlens), lse_packed, stream);
 }

 void nvte_cp_thd_read_second_half_lse(const NVTETensor &lse, const NVTETensor &cu_seqlens,
@@ -700,8 +700,8 @@ void nvte_cp_thd_read_second_half_lse(const NVTETensor &lse, const NVTETensor &c
   using namespace transformer_engine;
   context_parallel::thd_read_second_half_lse(
-      *reinterpret_cast<Tensor *>(lse), *reinterpret_cast<Tensor *>(cu_seqlens),
-      *reinterpret_cast<Tensor *>(half_lse), lse_packed, second_half_lse_seqlen, stream);
+      *convertNVTETensorCheck(lse), *convertNVTETensorCheck(cu_seqlens),
+      *convertNVTETensorCheck(half_lse), lse_packed, second_half_lse_seqlen, stream);
 }

 void nvte_cp_thd_out_correction(NVTETensor out, const NVTETensor &out_per_step,
@@ -712,9 +712,9 @@ void nvte_cp_thd_out_correction(NVTETensor out, const NVTETensor &out_per_step,
   using namespace transformer_engine;
   context_parallel::thd_out_correction(
-      *reinterpret_cast<Tensor *>(out), *reinterpret_cast<Tensor *>(out_per_step),
-      *reinterpret_cast<Tensor *>(lse), *reinterpret_cast<Tensor *>(lse_per_step),
-      *reinterpret_cast<Tensor *>(cu_seqlens), only_second_half, lse_packed, stream);
+      *convertNVTETensorCheck(out), *convertNVTETensorCheck(out_per_step),
+      *convertNVTETensorCheck(lse), *convertNVTETensorCheck(lse_per_step),
+      *convertNVTETensorCheck(cu_seqlens), only_second_half, lse_packed, stream);
 }

 void nvte_cp_thd_grad_correction(NVTETensor grad, const NVTETensor &grad_per_step,
@@ -727,8 +727,8 @@ void nvte_cp_thd_grad_correction(NVTETensor grad, const NVTETensor &grad_per_ste
   std::string second_half_str(second_half);
   context_parallel::thd_grad_correction(
-      *reinterpret_cast<Tensor *>(grad), *reinterpret_cast<Tensor *>(grad_per_step),
-      *reinterpret_cast<Tensor *>(cu_seqlens), first_half_str, second_half_str, stream);
+      *convertNVTETensorCheck(grad), *convertNVTETensorCheck(grad_per_step),
+      *convertNVTETensorCheck(cu_seqlens), first_half_str, second_half_str, stream);
 }

 void nvte_cp_thd_get_partitioned_indices(const NVTETensor &cu_seqlens, NVTETensor output,
@@ -737,7 +737,7 @@ void nvte_cp_thd_get_partitioned_indices(const NVTETensor &cu_seqlens, NVTETenso
   NVTE_API_CALL(nvte_thd_get_partitioned_indices);
   using namespace transformer_engine;
-  context_parallel::thd_get_partitioned_indices(*reinterpret_cast<Tensor *>(cu_seqlens),
-                                                *reinterpret_cast<Tensor *>(output), total_tokens,
+  context_parallel::thd_get_partitioned_indices(*convertNVTETensorCheck(cu_seqlens),
+                                                *convertNVTETensorCheck(output), total_tokens,
                                                 world_size, rank, stream);
 }
@@ -138,8 +138,8 @@ void nvte_prepare_flash_attn_fwd(NVTETensor qkvi, NVTETensor qkv, cudaStream_t s
   NVTE_API_CALL(nvte_prepare_flash_attn_fwd);
   using namespace transformer_engine;
-  flash_attention::prepare_flash_attn_fwd(*reinterpret_cast<Tensor *>(qkvi),
-                                          *reinterpret_cast<Tensor *>(qkv), stream);
+  flash_attention::prepare_flash_attn_fwd(*convertNVTETensorCheck(qkvi),
+                                          *convertNVTETensorCheck(qkv), stream);
 }

 void nvte_prepare_flash_attn_bwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTETensor qkv,
@@ -147,7 +147,7 @@ void nvte_prepare_flash_attn_bwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTET
   NVTE_API_CALL(nvte_prepare_flash_attn_bwd);
   using namespace transformer_engine;
-  flash_attention::prepare_flash_attn_bwd(
-      *reinterpret_cast<Tensor *>(q), *reinterpret_cast<Tensor *>(k),
-      *reinterpret_cast<Tensor *>(v), *reinterpret_cast<Tensor *>(qkv), stream);
+  flash_attention::prepare_flash_attn_bwd(*convertNVTETensorCheck(q), *convertNVTETensorCheck(k),
+                                          *convertNVTETensorCheck(v), *convertNVTETensorCheck(qkv),
+                                          stream);
 }
@@ -134,10 +134,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) {
 // select a backend for fused attention
 NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
-    NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups,
-    size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v,
-    int64_t window_size_left, int64_t window_size_right) {
+    bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads,
+    size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
+    size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) {
   using namespace transformer_engine;
   NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
   const int device_id = cuda::current_device();
@@ -216,12 +216,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
   }
   if (
       // TODO(cyang): replace with cudnn-frontend check_support for cleaner logic and better error messaging
-      // special conditions for blackwell
-      // TODO: enable THD max_t in f16_arbitrary_seqlen when support becomes available in 9.7
-      !(sm_arch_ >= 100 && (head_dim_qk > 128 || head_dim_v > 128)) &&
       // architecture
-      ((cudnn_runtime_version >= 8903 && sm_arch_ >= 80) ||
-       (cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90))) &&
+      ((cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90)) ||
+       (cudnn_runtime_version >= 8903 && sm_arch_ >= 80 && sm_arch_ < 100) ||
+       (cudnn_runtime_version >= 90700 && sm_arch_ >= 80)) &&
       // sequence length
       ((cudnn_runtime_version < 90000 && max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0) ||
       (cudnn_runtime_version >= 90000)) &&
@@ -229,11 +227,28 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
       ((cudnn_runtime_version < 8907 && num_attn_heads == num_gqa_groups) ||
       (cudnn_runtime_version >= 8907)) &&
       // head dimension
-      ((head_dim_qk <= 128 && head_dim_qk % 8 == 0 && head_dim_v <= 128 && head_dim_v % 8 == 0) ||
-       // TODO (cyang): add is_training to nvte_get_fused_attn_backend
-       // d=256 only supported for forward
-       (sm_arch_ >= 90 && cudnn_runtime_version >= 90000 && head_dim_qk <= 256 &&
-        head_dim_qk % 8 == 0 && head_dim_v <= 256 && head_dim_v % 8 == 0)) &&
+      // multiples of 8
+      (head_dim_qk % 8 == 0 && head_dim_v % 8 == 0 &&
+       // <= 128
+       ((head_dim_qk <= 128 && head_dim_v <= 128) ||
+        // 9.1: <= 256 + Hopper + fprop
+        // 9.5: <= 256 + Hopper + bprop
+        (head_dim_qk <= 256 && head_dim_v <= 256 &&
+         ((!is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90100) ||
+          (is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90500))) ||
+        // 9.9: any head_dim + Blackwell + fprop + non_paged + sq > 1
+        (!is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 90900 && max_seqlen_q > 1 &&
+         layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) ||
+        // 9.10: any head_dim + any arch + fprop + paged
+        // 9.10: any head_dim + any arch + fprop + non_paged + sq > 1
+        // 9.10: any head_dim + any arch + fprop + non_paged + sq = 1 + {no_mask, padding, BRCM, padding_BRCM}
+        (!is_training && cudnn_runtime_version >= 91000 &&
+         (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD || max_seqlen_q > 1 ||
+          (max_seqlen_q == 1 && attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_MASK &&
+           attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) ||
+        // 9.11: d_qk = 192, d_v = 128 + Blackwell + bprop + non-paged
+        (head_dim_qk == 192 && head_dim_v == 128 && is_training && sm_arch_ >= 100 &&
+         cudnn_runtime_version >= 91100))) &&
       // bias type
       ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) ||
       (cudnn_runtime_version >= 8906 &&
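For readability, the new head-dimension clause can be restated as a standalone predicate. This is a transcription of the condition above with local stand-in parameter names, not Transformer Engine API:

```cpp
// Transcription of the head-dim clause as a standalone predicate.
#include <cassert>
#include <cstddef>

bool head_dim_supported(bool is_training, int sm_arch, int cudnn, size_t d_qk, size_t d_v,
                        size_t max_seqlen_q, bool is_paged, bool causal_like_mask) {
  if (d_qk % 8 != 0 || d_v % 8 != 0) return false;  // always: multiples of 8
  if (d_qk <= 128 && d_v <= 128) return true;       // baseline support
  if (d_qk <= 256 && d_v <= 256 && sm_arch == 90)   // <= 256 on Hopper
    return is_training ? cudnn >= 90500 : cudnn >= 90100;  // 9.5 bprop / 9.1 fprop
  if (!is_training && sm_arch >= 100 && cudnn >= 90900 && max_seqlen_q > 1 && !is_paged)
    return true;                                    // 9.9: Blackwell fprop, non-paged, sq > 1
  if (!is_training && cudnn >= 91000 &&
      (is_paged || max_seqlen_q > 1 || !causal_like_mask))
    return true;                                    // 9.10: any arch, fprop
  if (is_training && d_qk == 192 && d_v == 128 && sm_arch >= 100 && cudnn >= 91100)
    return true;                                    // 9.11: Blackwell bprop, d_qk=192/d_v=128
  return false;
}

int main() {
  // d=256 on Hopper: fprop works from cuDNN 9.1, bprop only from 9.5.
  assert(head_dim_supported(false, 90, 90100, 256, 256, 128, false, true));
  assert(!head_dim_supported(true, 90, 90100, 256, 256, 128, false, true));
  return 0;
}
```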
@@ -392,14 +407,14 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
   NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked);
   using namespace transformer_engine;
-  const Tensor *input_cu_seqlens = reinterpret_cast<const Tensor *>(cu_seqlens);
-  const Tensor *input_cu_seqlens_padded = reinterpret_cast<const Tensor *>(cu_seqlens_padded);
-  const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(rng_state);
-  const Tensor *input_QKV = reinterpret_cast<const Tensor *>(QKV);
-  const Tensor *input_Bias = reinterpret_cast<const Tensor *>(Bias);
-  Tensor *input_output_S = reinterpret_cast<Tensor *>(S);
-  Tensor *output_O = reinterpret_cast<Tensor *>(O);
-  Tensor *wkspace = reinterpret_cast<Tensor *>(workspace);
+  const Tensor *input_cu_seqlens = convertNVTETensorCheck(cu_seqlens);
+  const Tensor *input_cu_seqlens_padded = convertNVTETensorCheck(cu_seqlens_padded);
+  const Tensor *input_rng_state = convertNVTETensorCheck(rng_state);
+  const Tensor *input_QKV = convertNVTETensorCheck(QKV);
+  const Tensor *input_Bias = convertNVTETensorCheck(Bias);
+  Tensor *input_output_S = convertNVTETensorCheck(S);
+  Tensor *output_O = convertNVTETensorCheck(O);
+  Tensor *wkspace = convertNVTETensor(workspace);
   auto ndim = input_QKV->data.shape.size();
   size_t b = input_cu_seqlens->data.shape[0] - 1;
@@ -423,8 +438,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
   const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen,
-      max_seqlen, d, d, window_size_left, window_size_right);
+      is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h,
+      max_seqlen, max_seqlen, d, d, window_size_left, window_size_right);
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
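`is_training` is now the first argument of `nvte_get_fused_attn_backend`: the forward entry points forward the caller's flag, while every `*_bwd` entry point below passes `true`, since reaching a backward pass implies training. The distinction matters because, per the support matrix above, some configurations are forward-only. A toy illustration (stub logic, not the real selector):

```cpp
// Why the selector needs is_training: the same shapes can resolve to
// different backends between inference and training.
#include <cstddef>
#include <cstdio>

enum class Backend { None, F16Arbitrary };

Backend select_backend(bool is_training, int sm_arch, int cudnn, size_t head_dim,
                       size_t max_seqlen_q) {
  if (head_dim <= 128) return Backend::F16Arbitrary;  // baseline support
  // e.g. head_dim > 128 on Blackwell: fprop-only from cuDNN 9.9 (see matrix above)
  if (!is_training && sm_arch >= 100 && cudnn >= 90900 && max_seqlen_q > 1)
    return Backend::F16Arbitrary;
  return Backend::None;
}

int main() {
  // Same shape, different answer depending on is_training:
  std::printf("inference -> %d\n", static_cast<int>(select_backend(false, 100, 90900, 256, 512)));
  std::printf("training  -> %d\n", static_cast<int>(select_backend(true, 100, 90900, 256, 512)));
  return 0;
}
```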
@@ -472,16 +487,16 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
   NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked);
   using namespace transformer_engine;
-  const Tensor *input_cu_seqlens = reinterpret_cast<const Tensor *>(cu_seqlens);
-  const Tensor *input_cu_seqlens_padded = reinterpret_cast<const Tensor *>(cu_seqlens_padded);
-  const Tensor *input_QKV = reinterpret_cast<const Tensor *>(QKV);
-  const Tensor *input_O = reinterpret_cast<const Tensor *>(O);
-  const Tensor *input_dO = reinterpret_cast<const Tensor *>(dO);
-  const Tensor *input_S = reinterpret_cast<const Tensor *>(S);
-  Tensor *input_output_dP = reinterpret_cast<Tensor *>(dP);
-  Tensor *output_dQKV = reinterpret_cast<Tensor *>(dQKV);
-  Tensor *output_dBias = reinterpret_cast<Tensor *>(dBias);
-  Tensor *wkspace = reinterpret_cast<Tensor *>(workspace);
+  const Tensor *input_cu_seqlens = convertNVTETensorCheck(cu_seqlens);
+  const Tensor *input_cu_seqlens_padded = convertNVTETensorCheck(cu_seqlens_padded);
+  const Tensor *input_QKV = convertNVTETensorCheck(QKV);
+  const Tensor *input_O = convertNVTETensorCheck(O);
+  const Tensor *input_dO = convertNVTETensorCheck(dO);
+  const Tensor *input_S = convertNVTETensorCheck(S);
+  Tensor *input_output_dP = convertNVTETensorCheck(dP);
+  Tensor *output_dQKV = convertNVTETensorCheck(dQKV);
+  Tensor *output_dBias = convertNVTETensorCheck(dBias);
+  Tensor *wkspace = convertNVTETensor(workspace);
   auto ndim = input_QKV->data.shape.size();
   size_t b = input_cu_seqlens->data.shape[0] - 1;
@@ -505,12 +520,12 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
   const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen,
+      true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen,
       max_seqlen, d, d, window_size_left, window_size_right);
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
-    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
     fused_attn_max_512_bwd_qkvpacked(
         b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_QKV,
         input_dO, output_S, output_dQKV, output_dBias, input_cu_seqlens, wkspace, stream, handle);
@@ -519,13 +534,13 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
 #endif
   } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
 #if (CUDNN_VERSION >= 8900)
-    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
     Tensor *input_Bias, *input_rng_state;
     if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
-      input_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
-      input_Bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
+      input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+      input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
     } else {
-      input_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+      input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
     }
     fused_attn_arbitrary_seqlen_bwd_qkvpacked(
         b, h, max_seqlen, d, t, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
@@ -540,9 +555,9 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
 #endif
   } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
 #if (CUDNN_VERSION >= 8900)
-    const Tensor *input_M = reinterpret_cast<const Tensor *>(Aux_CTX_Tensors->tensors[0]);
-    const Tensor *input_ZInv = reinterpret_cast<const Tensor *>(Aux_CTX_Tensors->tensors[1]);
-    const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(Aux_CTX_Tensors->tensors[2]);
+    const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
+    const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+    const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
     fused_attn_fp8_bwd_qkvpacked(b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type,
                                  attn_mask_type, input_QKV, input_O, input_dO, input_M, input_ZInv,
                                  input_S, input_output_dP, output_dQKV, input_cu_seqlens,
@@ -566,19 +581,19 @@ void nvte_fused_attn_fwd_kvpacked(
     cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked);
   using namespace transformer_engine;
-  const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
-  const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv);
-  const Tensor *input_cu_seqlens_q_padded = reinterpret_cast<const Tensor *>(cu_seqlens_q_padded);
-  const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast<const Tensor *>(cu_seqlens_kv_padded);
-  const Tensor *input_page_table_k = reinterpret_cast<const Tensor *>(page_table_k);
-  const Tensor *input_page_table_v = reinterpret_cast<const Tensor *>(page_table_v);
-  const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(rng_state);
-  const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q);
-  const Tensor *input_KV = reinterpret_cast<const Tensor *>(KV);
-  const Tensor *input_Bias = reinterpret_cast<const Tensor *>(Bias);
-  Tensor *input_output_S = reinterpret_cast<Tensor *>(S);
-  Tensor *output_O = reinterpret_cast<Tensor *>(O);
-  Tensor *wkspace = reinterpret_cast<Tensor *>(workspace);
+  const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
+  const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv);
+  const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded);
+  const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded);
+  const Tensor *input_page_table_k = convertNVTETensorCheck(page_table_k);
+  const Tensor *input_page_table_v = convertNVTETensorCheck(page_table_v);
+  const Tensor *input_rng_state = convertNVTETensorCheck(rng_state);
+  const Tensor *input_Q = convertNVTETensorCheck(Q);
+  const Tensor *input_KV = convertNVTETensorCheck(KV);
+  const Tensor *input_Bias = convertNVTETensorCheck(Bias);
+  Tensor *input_output_S = convertNVTETensorCheck(S);
+  Tensor *output_O = convertNVTETensorCheck(O);
+  Tensor *wkspace = convertNVTETensor(workspace);
   size_t b = input_cu_seqlens_q->data.shape[0] - 1;
   auto ndim = input_Q->data.shape.size();
@@ -636,8 +651,8 @@ void nvte_fused_attn_fwd_kvpacked(
   const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
-      max_seqlen_kv, d, d, window_size_left, window_size_right);
+      is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
+      max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right);
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -686,20 +701,20 @@ void nvte_fused_attn_bwd_kvpacked(
     cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked);
   using namespace transformer_engine;
-  const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
-  const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv);
-  const Tensor *input_cu_seqlens_q_padded = reinterpret_cast<const Tensor *>(cu_seqlens_q_padded);
-  const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast<const Tensor *>(cu_seqlens_kv_padded);
-  const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q);
-  const Tensor *input_KV = reinterpret_cast<const Tensor *>(KV);
-  const Tensor *input_O = reinterpret_cast<const Tensor *>(O);
-  const Tensor *input_dO = reinterpret_cast<const Tensor *>(dO);
-  const Tensor *input_S = reinterpret_cast<const Tensor *>(S);
-  Tensor *input_output_dP = reinterpret_cast<Tensor *>(dP);
-  Tensor *output_dQ = reinterpret_cast<Tensor *>(dQ);
-  Tensor *output_dKV = reinterpret_cast<Tensor *>(dKV);
-  Tensor *output_dBias = reinterpret_cast<Tensor *>(dBias);
-  Tensor *wkspace = reinterpret_cast<Tensor *>(workspace);
+  const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
+  const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv);
+  const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded);
+  const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded);
+  const Tensor *input_Q = convertNVTETensorCheck(Q);
+  const Tensor *input_KV = convertNVTETensorCheck(KV);
+  const Tensor *input_O = convertNVTETensorCheck(O);
+  const Tensor *input_dO = convertNVTETensorCheck(dO);
+  const Tensor *input_S = convertNVTETensorCheck(S);
+  Tensor *input_output_dP = convertNVTETensorCheck(dP);
+  Tensor *output_dQ = convertNVTETensorCheck(dQ);
+  Tensor *output_dKV = convertNVTETensorCheck(dKV);
+  Tensor *output_dBias = convertNVTETensorCheck(dBias);
+  Tensor *wkspace = convertNVTETensor(workspace);
   size_t b = input_cu_seqlens_q->data.shape[0] - 1;
   auto ndim = input_Q->data.shape.size();
@@ -731,12 +746,12 @@ void nvte_fused_attn_bwd_kvpacked(
   const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
-      max_seqlen_kv, d, d, window_size_left, window_size_right);
+      true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
+      max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right);
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
-    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
     fused_attn_max_512_bwd_kvpacked(
         b, h_q, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type,
         attn_mask_type, input_Q, input_KV, input_dO, output_S, output_dQ, output_dKV, output_dBias,
@@ -746,13 +761,13 @@ void nvte_fused_attn_bwd_kvpacked(
 #endif
   } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
 #if (CUDNN_VERSION >= 8903)
-    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
     Tensor *input_Bias, *input_rng_state;
     if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
-      input_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
-      input_Bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
+      input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+      input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
     } else {
-      input_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+      input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
     }
     fused_attn_arbitrary_seqlen_bwd_kvpacked(
         b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, attn_scale, dropout, qkv_layout,
@@ -768,9 +783,9 @@ void nvte_fused_attn_bwd_kvpacked(
 #endif
   } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
 #if (CUDNN_VERSION >= 8900)
-    const Tensor *input_M = reinterpret_cast<const Tensor *>(Aux_CTX_Tensors->tensors[0]);
-    const Tensor *input_ZInv = reinterpret_cast<const Tensor *>(Aux_CTX_Tensors->tensors[1]);
-    const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(Aux_CTX_Tensors->tensors[2]);
+    const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
+    const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+    const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
     fused_attn_fp8_bwd_kvpacked(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout,
                                 qkv_layout, bias_type, attn_mask_type, input_Q, input_KV, input_O,
                                 input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ,
@@ -797,20 +812,20 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
     cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_fwd);
   using namespace transformer_engine;
-  const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
-  const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv);
-  const Tensor *input_cu_seqlens_q_padded = reinterpret_cast<const Tensor *>(cu_seqlens_q_padded);
-  const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast<const Tensor *>(cu_seqlens_kv_padded);
-  const Tensor *input_page_table_k = reinterpret_cast<const Tensor *>(page_table_k);
-  const Tensor *input_page_table_v = reinterpret_cast<const Tensor *>(page_table_v);
-  const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(rng_state);
-  const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q);
-  const Tensor *input_K = reinterpret_cast<const Tensor *>(K);
-  const Tensor *input_V = reinterpret_cast<const Tensor *>(V);
-  const Tensor *input_Bias = reinterpret_cast<const Tensor *>(Bias);
-  Tensor *input_output_S = reinterpret_cast<Tensor *>(S);
-  Tensor *output_O = reinterpret_cast<Tensor *>(O);
-  Tensor *wkspace = reinterpret_cast<Tensor *>(workspace);
+  const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
+  const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv);
+  const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded);
+  const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded);
+  const Tensor *input_page_table_k = convertNVTETensorCheck(page_table_k);
+  const Tensor *input_page_table_v = convertNVTETensorCheck(page_table_v);
+  const Tensor *input_rng_state = convertNVTETensorCheck(rng_state);
+  const Tensor *input_Q = convertNVTETensorCheck(Q);
+  const Tensor *input_K = convertNVTETensorCheck(K);
+  const Tensor *input_V = convertNVTETensorCheck(V);
+  const Tensor *input_Bias = convertNVTETensorCheck(Bias);
+  Tensor *input_output_S = convertNVTETensorCheck(S);
+  Tensor *output_O = convertNVTETensorCheck(O);
+  Tensor *wkspace = convertNVTETensor(workspace);
   auto ndim = input_Q->data.shape.size();
   auto ndim_kv = input_K->data.shape.size();
@@ -862,8 +877,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
   const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
-      max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
+      is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
+      max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -914,22 +929,22 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
     cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_bwd);
   using namespace transformer_engine;
-  const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
-  const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv);
-  const Tensor *input_cu_seqlens_q_padded = reinterpret_cast<const Tensor *>(cu_seqlens_q_padded);
-  const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast<const Tensor *>(cu_seqlens_kv_padded);
-  const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q);
-  const Tensor *input_K = reinterpret_cast<const Tensor *>(K);
-  const Tensor *input_V = reinterpret_cast<const Tensor *>(V);
-  const Tensor *input_O = reinterpret_cast<const Tensor *>(O);
-  const Tensor *input_dO = reinterpret_cast<const Tensor *>(dO);
-  const Tensor *input_S = reinterpret_cast<const Tensor *>(S);
-  Tensor *input_output_dP = reinterpret_cast<Tensor *>(dP);
-  Tensor *output_dQ = reinterpret_cast<Tensor *>(dQ);
-  Tensor *output_dK = reinterpret_cast<Tensor *>(dK);
-  Tensor *output_dV = reinterpret_cast<Tensor *>(dV);
-  Tensor *output_dBias = reinterpret_cast<Tensor *>(dBias);
-  Tensor *wkspace = reinterpret_cast<Tensor *>(workspace);
+  const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
+  const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv);
+  const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded);
+  const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded);
+  const Tensor *input_Q = convertNVTETensorCheck(Q);
+  const Tensor *input_K = convertNVTETensorCheck(K);
+  const Tensor *input_V = convertNVTETensorCheck(V);
+  const Tensor *input_O = convertNVTETensorCheck(O);
+  const Tensor *input_dO = convertNVTETensorCheck(dO);
+  const Tensor *input_S = convertNVTETensorCheck(S);
+  Tensor *input_output_dP = convertNVTETensorCheck(dP);
+  Tensor *output_dQ = convertNVTETensorCheck(dQ);
+  Tensor *output_dK = convertNVTETensorCheck(dK);
+  Tensor *output_dV = convertNVTETensorCheck(dV);
+  Tensor *output_dBias = convertNVTETensorCheck(dBias);
+  Tensor *wkspace = convertNVTETensor(workspace);
   auto ndim = input_Q->data.shape.size();
   auto ndim_kv = input_K->data.shape.size();
@@ -954,12 +969,12 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
   const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
-      max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
+      true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
+      max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
-    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
     fused_attn_max_512_bwd(b, h_q, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout,
                            qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V,
                            input_dO, output_S, output_dQ, output_dK, output_dV, output_dBias,
@@ -969,13 +984,13 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
 #endif
   } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
 #if (CUDNN_VERSION >= 8900)
-    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
     Tensor *input_Bias, *input_rng_state;
     if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
-      input_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
-      input_Bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
+      input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+      input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
     } else {
-      input_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+      input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
     }
     fused_attn_arbitrary_seqlen_bwd(
         b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, dropout,
@@ -991,9 +1006,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
 #endif
   } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
 #if (CUDNN_VERSION >= 8900)
-    const Tensor *input_M = reinterpret_cast<const Tensor *>(Aux_CTX_Tensors->tensors[0]);
-    const Tensor *input_ZInv = reinterpret_cast<const Tensor *>(Aux_CTX_Tensors->tensors[1]);
-    const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(Aux_CTX_Tensors->tensors[2]);
+    const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
+    const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+    const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
     fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout,
                        qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, input_O,
                        input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ,
@@ -377,7 +377,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
   const size_t num_bytes_per_seqlen = alignTo<16>(b * sizeof(int32_t));
   const size_t actual_seqlen_workspace_size = is_padding ? 2 * num_bytes_per_seqlen : 0;
   const size_t num_bytes_per_ragged_offset =
-      alignTo<16>((b + 1) * typeToSize(ragged_offset_type));
+      alignTo<16>(((b + 1) * typeToNumBits(ragged_offset_type)) / 8);
   size_t seqlen_offsets_workspace_size = 0;
   if (is_ragged_q || is_ragged_kv) {
     size_t count = 2 * (static_cast<size_t>(is_ragged_q) + static_cast<size_t>(is_ragged_kv));
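Same bits-over-bytes rewrite as in the context-parallel file, here applied to the ragged-offset workspace sizing. A small self-contained sketch, assuming an `alignTo<N>` helper that rounds a byte count up to a multiple of N (the helper body below is illustrative):

```cpp
// Sketch of the ragged-offset workspace sizing.
#include <cassert>
#include <cstddef>

template <size_t N>
constexpr size_t alignTo(size_t bytes) {
  return (bytes + N - 1) / N * N;  // round up to a multiple of N
}

int main() {
  const size_t b = 16;            // batch size: b + 1 cumulative offsets
  const size_t offset_bits = 64;  // e.g. typeToNumBits of an int64 offset type
  // Size in bits, convert to bytes, then pad to a 16-byte boundary so the
  // per-tensor offset buffers stay aligned.
  const size_t num_bytes_per_ragged_offset = alignTo<16>(((b + 1) * offset_bits) / 8);
  assert(num_bytes_per_ragged_offset == 144);  // 17 * 8 = 136 -> padded to 144
  return 0;
}
```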
@@ -831,7 +831,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
   const size_t num_bytes_per_seqlen = alignTo<16>(b * sizeof(int32_t));
   const size_t actual_seqlen_workspace_size = is_padding ? 2 * num_bytes_per_seqlen : 0;
   const size_t num_bytes_per_ragged_offset =
-      alignTo<16>((b + 1) * typeToSize(ragged_offset_type));
+      alignTo<16>(((b + 1) * typeToNumBits(ragged_offset_type)) / 8);
   size_t seqlen_offsets_workspace_size = 0;
   if (is_ragged_q || is_ragged_kv) {
     size_t count = 2 * (static_cast<size_t>(is_ragged_q) + static_cast<size_t>(is_ragged_kv));
@@ -957,9 +957,9 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
   NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
   size_t stride = 0;
   if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
-    stride = typeToSize(QKV_type) * num_attn_heads * head_dim;
+    stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8;
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
-    stride = typeToSize(QKV_type) * head_dim;
+    stride = (typeToNumBits(QKV_type) * head_dim) / 8;
   }
   void *devPtrQ = static_cast<void *>(devPtrQKV);
   void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
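The stride here is the byte offset from the Q block to the K block inside packed QKV storage: for `NVTE_3HD` the per-token element order is `[3][h][d]`, so K starts `h * d` elements after Q; for `NVTE_H3D` it is `[h][3][d]`, so K starts `d` elements after Q. A stand-alone check of the arithmetic with assumed example sizes:

```cpp
// Packed-QKV stride arithmetic with assumed sizes.
#include <cassert>
#include <cstddef>

int main() {
  const size_t num_attn_heads = 16, head_dim = 64;
  const size_t bits = 16;  // e.g. typeToNumBits of FP16

  // NVTE_3HD: per-token order [3][h][d] -> K begins one h*d block after Q.
  const size_t stride_3hd = (bits * num_attn_heads * head_dim) / 8;
  // NVTE_H3D: per-token order [h][3][d] -> K begins one d block after Q.
  const size_t stride_h3d = (bits * head_dim) / 8;

  assert(stride_3hd == 2048);
  assert(stride_h3d == 128);
  return 0;
}
```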
@@ -990,7 +990,7 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
     const auto cudnn_runtime_version = cudnnGetVersion();
     if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
       Aux_CTX_Tensors->size = 3;
-      Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+      Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
       output_S->data.dptr = nullptr;
       if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
         output_S->data.shape = {max_tokens, num_attn_heads, 1};
@@ -998,17 +998,17 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
         output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1};
       }
       output_S->data.dtype = DType::kFloat32;
-      Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+      Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
       output_rng_state->data.dptr = nullptr;
       output_rng_state->data.shape = {2};
       output_rng_state->data.dtype = DType::kInt64;
-      Tensor *output_bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
+      Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
       output_bias->data.dptr = nullptr;
       output_bias->data.shape = {bias_b, bias_h, max_seqlen, max_seqlen};
       output_bias->data.dtype = QKV_type;
     } else {
       Aux_CTX_Tensors->size = 2;
-      Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+      Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
       output_S->data.dptr = nullptr;
       if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
         output_S->data.shape = {max_tokens, num_attn_heads, 1};
@@ -1016,22 +1016,22 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
         output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1};
       }
       output_S->data.dtype = DType::kFloat32;
-      Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+      Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
       output_rng_state->data.dptr = nullptr;
       output_rng_state->data.shape = {2};
       output_rng_state->data.dtype = DType::kInt64;
     }
   } else if (Aux_CTX_Tensors->size == 2) {
-    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
     devPtrS = output_S->data.dptr;
-    Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+    Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
     output_rng_state->data.dptr = rng_state->data.dptr;
   } else if (Aux_CTX_Tensors->size == 3) {
-    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
     devPtrS = output_S->data.dptr;
-    Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+    Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
     output_rng_state->data.dptr = rng_state->data.dptr;
-    Tensor *output_bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
+    Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
     output_bias->data.dptr = devPtrBias;
   } else {
     NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
@@ -1082,9 +1082,9 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
   NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
   size_t stride = 0;
   if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
-    stride = typeToSize(QKV_type) * num_attn_heads * head_dim;
+    stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8;
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
-    stride = typeToSize(QKV_type) * head_dim;
+    stride = (typeToNumBits(QKV_type) * head_dim) / 8;
   }
   void *devPtrQ = devPtrQKV;
   void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
@@ -1173,9 +1173,9 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
   NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
   size_t stride = 0;
   if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
-    stride = typeToSize(QKV_type) * num_gqa_groups * head_dim;
+    stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8;
   } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) {
-    stride = typeToSize(QKV_type) * head_dim;
+    stride = (typeToNumBits(QKV_type) * head_dim) / 8;
   }
   void *devPtrK = devPtrKV;
   void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrKV) + stride);
...@@ -1216,7 +1216,7 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( ...@@ -1216,7 +1216,7 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
const auto cudnn_runtime_version = cudnnGetVersion(); const auto cudnn_runtime_version = cudnnGetVersion();
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
Aux_CTX_Tensors->size = 3; Aux_CTX_Tensors->size = 3;
Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]); Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr; output_S->data.dptr = nullptr;
if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
...@@ -1224,17 +1224,17 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( ...@@ -1224,17 +1224,17 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
}
output_S->data.dtype = DType::kFloat32;
-Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
-Tensor *output_bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
+Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
output_bias->data.dptr = nullptr;
output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv};
output_bias->data.dtype = QKV_type;
} else {
Aux_CTX_Tensors->size = 2;
-Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
@@ -1242,22 +1242,22 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
}
output_S->data.dtype = DType::kFloat32;
-Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
}
} else if (Aux_CTX_Tensors->size == 2) {
-Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
-Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = rng_state->data.dptr;
} else if (Aux_CTX_Tensors->size == 3) {
-Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
-Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = rng_state->data.dptr;
-Tensor *output_bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
+Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
output_bias->data.dptr = devPtrBias;
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
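Reviewer note: the `Aux_CTX_Tensors` logic above (and repeated in the functions below) is a two-phase protocol. A first call with `size == 0` only publishes shapes and dtypes so the caller can allocate; a second call with `size == 2` or `3` binds the caller's freshly allocated pointers. A minimal sketch of the pattern, with simplified stand-in types (`SimpleTensor` and `AuxPack` are illustrative, not TE's structs):

#include <cstddef>
#include <vector>

// Illustrative stand-ins for the TE types used in this diff.
struct SimpleTensor { void *dptr = nullptr; std::vector<size_t> shape; };
struct AuxPack { int size = 0; SimpleTensor tensors[3]; };

void fwd_aux_protocol(AuxPack *aux, void *rng_dptr, void **devPtrS) {
  if (aux->size == 0) {
    // Phase 1: only report what the caller must allocate.
    aux->size = 2;
    aux->tensors[0].shape = {8, 16, 512, 1};  // softmax stats S (example dims)
    aux->tensors[1].shape = {2};              // RNG state {seed, offset}
  } else if (aux->size == 2) {
    // Phase 2: the caller allocated; bind pointers for the kernel launch.
    *devPtrS = aux->tensors[0].dptr;
    aux->tensors[1].dptr = rng_dptr;
  }
}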
@@ -1313,9 +1313,9 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
-stride = typeToSize(QKV_type) * num_gqa_groups * head_dim;
+stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) {
-stride = typeToSize(QKV_type) * head_dim;
+stride = (typeToNumBits(QKV_type) * head_dim) / 8;
}
void *devPtrK = devPtrKV;
void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrKV) + stride);
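The switch from `typeToSize` (whole bytes per element) to `typeToNumBits` keeps the arithmetic exact for sub-byte dtypes, where a per-element byte count would round incorrectly; multiplying in bits first and dividing by 8 at the end only requires the total to be byte-aligned. A self-contained sketch of the intent (the dtype table here is an illustrative subset, not the library's full mapping):

#include <cstddef>

// Hypothetical subset of the dtype -> bit-width mapping; the real
// typeToNumBits covers every transformer_engine::DType.
enum class DType { kFloat16, kFloat8E4M3, kFloat4E2M1 };

constexpr size_t typeToNumBits(DType t) {
  switch (t) {
    case DType::kFloat16:    return 16;
    case DType::kFloat8E4M3: return 8;
    case DType::kFloat4E2M1: return 4;  // sub-byte: a bytes-per-element API would be lossy
  }
  return 0;
}

// Byte offset of V inside a KV-packed buffer, computed in bits first so
// that e.g. 64 FP4 elements map to exactly 32 bytes.
size_t kv_stride_bytes(DType t, size_t num_gqa_groups, size_t head_dim) {
  return (typeToNumBits(t) * num_gqa_groups * head_dim) / 8;
}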
@@ -1446,7 +1446,7 @@ void fused_attn_arbitrary_seqlen_fwd(
const auto cudnn_runtime_version = cudnnGetVersion();
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
Aux_CTX_Tensors->size = 3;
-Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
@@ -1454,17 +1454,17 @@ void fused_attn_arbitrary_seqlen_fwd(
output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
}
output_S->data.dtype = DType::kFloat32;
-Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
-Tensor *output_bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
+Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
output_bias->data.dptr = nullptr;
output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv};
output_bias->data.dtype = QKV_type;
} else {
Aux_CTX_Tensors->size = 2;
-Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
@@ -1472,22 +1472,22 @@ void fused_attn_arbitrary_seqlen_fwd(
output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
}
output_S->data.dtype = DType::kFloat32;
-Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
}
} else if (Aux_CTX_Tensors->size == 2) {
-Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
-Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = rng_state->data.dptr;
} else if (Aux_CTX_Tensors->size == 3) {
-Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
-Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = rng_state->data.dptr;
-Tensor *output_bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
+Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
output_bias->data.dptr = devPtrBias;
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
@@ -1239,12 +1239,12 @@ void fused_attn_max_512_fwd_qkvpacked(
if (Aux_CTX_Tensors->size == 0) {
Aux_CTX_Tensors->size = 1;
-Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
output_S->data.shape = {batch, num_head, max_seqlen, max_seqlen};
output_S->data.dtype = input_QKV->data.dtype;
} else if (Aux_CTX_Tensors->size == 1) {
-Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
@@ -1317,12 +1317,12 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t num_head, size_t q_max
if (Aux_CTX_Tensors->size == 0) {
Aux_CTX_Tensors->size = 1;
-Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
output_S->data.shape = {batch, num_head, q_max_seqlen, kv_max_seqlen};
output_S->data.dtype = q_type;
} else if (Aux_CTX_Tensors->size == 1) {
-Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
@@ -1386,12 +1386,12 @@ void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen,
if (Aux_CTX_Tensors->size == 0) {
Aux_CTX_Tensors->size = 1;
-Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
output_S->data.shape = {batch, num_head, q_max_seqlen, kv_max_seqlen};
output_S->data.dtype = q_type;
} else if (Aux_CTX_Tensors->size == 1) {
-Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
@@ -2364,9 +2364,9 @@ void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t ma
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
-stride = typeToSize(QKV_type) * num_attn_heads * head_dim;
+stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
-stride = typeToSize(QKV_type) * head_dim;
+stride = (typeToNumBits(QKV_type) * head_dim) / 8;
}
void* devPtrQ = static_cast<void*>(devPtrQKV);
void* devPtrK = static_cast<void*>(static_cast<int8_t*>(devPtrQKV) + stride);
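For the packed-QKV layouts the three projections share one contiguous buffer, and only the element distance between Q, K and V differs: in `3HD` the buffer is effectively `[..., 3, h, d]` so consecutive projections sit `h * d` elements apart, while in `H3D` it is `[..., h, 3, d]` and they sit `d` elements apart. A small sketch of the pointer derivation under those assumptions (`bits_per_elem` plays the role of `typeToNumBits(QKV_type)`; this is not the library's helper):

#include <cstddef>
#include <cstdint>

enum class LayoutGroup { k3HD, kH3D };

// Derive Q/K/V pointers inside one packed QKV allocation.
void split_qkv(void *qkv, LayoutGroup g, size_t num_heads, size_t head_dim,
               size_t bits_per_elem, void **q, void **k, void **v) {
  size_t stride = 0;  // byte distance between consecutive projections
  if (g == LayoutGroup::k3HD) {
    stride = (bits_per_elem * num_heads * head_dim) / 8;
  } else {  // kH3D: projections interleave per head
    stride = (bits_per_elem * head_dim) / 8;
  }
  *q = qkv;
  *k = static_cast<void *>(static_cast<int8_t *>(qkv) + stride);
  *v = static_cast<void *>(static_cast<int8_t *>(qkv) + 2 * stride);
}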
@@ -2383,9 +2383,9 @@ void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t ma
void* devPtrZInv = nullptr;
if (Aux_CTX_Tensors->size == 0) {
Aux_CTX_Tensors->size = 3;
-Tensor* output_M = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[0]);
-Tensor* output_ZInv = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
-Tensor* output_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
+Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
+Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
output_M->data.dptr = nullptr;
output_M->data.shape = {batch, num_attn_heads, max_seqlen, 1};
output_M->data.dtype = DType::kFloat32;
@@ -2396,9 +2396,9 @@ void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t ma
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
} else if (Aux_CTX_Tensors->size == 3) {
-Tensor* output_M = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[0]);
-Tensor* output_ZInv = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
-Tensor* output_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
+Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
+Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
devPtrM = output_M->data.dptr;
devPtrZInv = output_ZInv->data.dptr;
output_rng_state->data.dptr = rng_state->data.dptr;
@@ -2466,9 +2466,9 @@ void fused_attn_fp8_bwd_qkvpacked(
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
-stride = typeToSize(QKV_type) * num_attn_heads * head_dim;
+stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
-stride = typeToSize(QKV_type) * head_dim;
+stride = (typeToNumBits(QKV_type) * head_dim) / 8;
}
void* devPtrQ = devPtrQKV;
void* devPtrK = static_cast<void*>(static_cast<int8_t*>(devPtrQKV) + stride);
@@ -2564,9 +2564,9 @@ void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
-stride = typeToSize(QKV_type) * num_gqa_groups * head_dim;
+stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) {
-stride = typeToSize(QKV_type) * head_dim;
+stride = (typeToNumBits(QKV_type) * head_dim) / 8;
}
void* devPtrK = devPtrKV;
void* devPtrV = static_cast<void*>(static_cast<int8_t*>(devPtrKV) + stride);
@@ -2582,9 +2582,9 @@ void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num
void* devPtrZInv = nullptr;
if (Aux_CTX_Tensors->size == 0) {
Aux_CTX_Tensors->size = 3;
-Tensor* output_M = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[0]);
-Tensor* output_ZInv = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
-Tensor* output_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
+Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
+Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
output_M->data.dptr = nullptr;
output_M->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
output_M->data.dtype = DType::kFloat32;
@@ -2595,9 +2595,9 @@ void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
} else if (Aux_CTX_Tensors->size == 3) {
-Tensor* output_M = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[0]);
-Tensor* output_ZInv = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
-Tensor* output_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
+Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
+Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
devPtrM = output_M->data.dptr;
devPtrZInv = output_ZInv->data.dptr;
output_rng_state->data.dptr = rng_state->data.dptr;
@@ -2671,9 +2671,9 @@ void fused_attn_fp8_bwd_kvpacked(
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
size_t stride = 0;
if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
-stride = typeToSize(QKV_type) * num_gqa_groups * head_dim;
+stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8;
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) {
-stride = typeToSize(QKV_type) * head_dim;
+stride = (typeToNumBits(QKV_type) * head_dim) / 8;
}
void* devPtrK = devPtrKV;
void* devPtrV = static_cast<void*>(static_cast<int8_t*>(devPtrKV) + stride);
@@ -2779,9 +2779,9 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou
void* devPtrZInv = nullptr;
if (Aux_CTX_Tensors->size == 0) {
Aux_CTX_Tensors->size = 3;
-Tensor* output_M = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[0]);
-Tensor* output_ZInv = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
-Tensor* output_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
+Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
+Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
output_M->data.dptr = nullptr;
output_M->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
output_M->data.dtype = DType::kFloat32;
@@ -2792,9 +2792,9 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
} else if (Aux_CTX_Tensors->size == 3) {
-Tensor* output_M = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[0]);
-Tensor* output_ZInv = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
-Tensor* output_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
+Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
+Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
devPtrM = output_M->data.dptr;
devPtrZInv = output_ZInv->data.dptr;
output_rng_state->data.dptr = rng_state->data.dptr;
@@ -260,12 +260,12 @@ void nvte_copy_to_kv_cache(NVTETensor new_k, NVTETensor new_v, NVTETensor k_cach
NVTE_API_CALL(nvte_copy_to_kv_cache);
using namespace transformer_engine;
-kv_cache::copy_to_kv_cache(
-    *reinterpret_cast<Tensor *>(new_k), *reinterpret_cast<Tensor *>(new_v),
-    *reinterpret_cast<Tensor *>(k_cache), *reinterpret_cast<Tensor *>(v_cache),
-    *reinterpret_cast<Tensor *>(page_table), *reinterpret_cast<Tensor *>(cu_new_lens),
-    *reinterpret_cast<Tensor *>(cu_cached_lens), qkv_format, b, max_ctx_len, max_seq_len,
-    max_pages_per_seq, is_non_paged, stream);
+kv_cache::copy_to_kv_cache(*convertNVTETensorCheck(new_k), *convertNVTETensorCheck(new_v),
+                           *convertNVTETensorCheck(k_cache), *convertNVTETensorCheck(v_cache),
+                           *convertNVTETensorCheck(page_table),
+                           *convertNVTETensorCheck(cu_new_lens),
+                           *convertNVTETensorCheck(cu_cached_lens), qkv_format, b, max_ctx_len,
+                           max_seq_len, max_pages_per_seq, is_non_paged, stream);
}
/***************************************************************************************************
@@ -277,9 +277,9 @@ void nvte_convert_thd_to_bshd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETens
NVTE_API_CALL(nvte_convert_thd_to_bshd);
using namespace transformer_engine;
-kv_cache::convert_thd_to_bshd(*reinterpret_cast<Tensor *>(tensor),
-                              *reinterpret_cast<Tensor *>(cu_seqlens),
-                              *reinterpret_cast<Tensor *>(new_tensor), b, max_seq_len, stream);
+kv_cache::convert_thd_to_bshd(*convertNVTETensorCheck(tensor),
+                              *convertNVTETensorCheck(cu_seqlens),
+                              *convertNVTETensorCheck(new_tensor), b, max_seq_len, stream);
}
/***************************************************************************************************
@@ -291,7 +291,7 @@ void nvte_convert_bshd_to_thd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETens
NVTE_API_CALL(nvte_convert_bshd_to_thd);
using namespace transformer_engine;
-kv_cache::convert_bshd_to_thd(*reinterpret_cast<Tensor *>(tensor),
-                              *reinterpret_cast<Tensor *>(cu_seqlens),
-                              *reinterpret_cast<Tensor *>(new_tensor), t, stream);
+kv_cache::convert_bshd_to_thd(*convertNVTETensorCheck(tensor),
+                              *convertNVTETensorCheck(cu_seqlens),
+                              *convertNVTETensorCheck(new_tensor), t, stream);
}
@@ -308,11 +308,10 @@ void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens
const int stride_d, cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_rope_forward);
using namespace transformer_engine;
-fused_rope_forward(
-    *reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(cu_seqlens),
-    *reinterpret_cast<const Tensor *>(freqs), *reinterpret_cast<const Tensor *>(start_positions),
-    reinterpret_cast<Tensor *>(output), qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2,
-    stride_s_or_t, stride_b, stride_h, stride_d, stream);
+fused_rope_forward(*convertNVTETensorCheck(input), *convertNVTETensorCheck(cu_seqlens),
+                   *convertNVTETensorCheck(freqs), *convertNVTETensorCheck(start_positions),
+                   convertNVTETensorCheck(output), qkv_format, interleaved, cp_size, cp_rank, s,
+                   b, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, stream);
}
void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
@@ -324,9 +323,8 @@ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu
cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_rope_backward);
using namespace transformer_engine;
-fused_rope_backward(*reinterpret_cast<const Tensor *>(output_grads),
-                    *reinterpret_cast<const Tensor *>(cu_seqlens),
-                    *reinterpret_cast<const Tensor *>(freqs),
-                    reinterpret_cast<Tensor *>(input_grads), qkv_format, interleaved, cp_size,
-                    cp_rank, s, b, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, stream);
+fused_rope_backward(*convertNVTETensorCheck(output_grads), *convertNVTETensorCheck(cu_seqlens),
+                    *convertNVTETensorCheck(freqs), convertNVTETensorCheck(input_grads),
+                    qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t,
+                    stride_b, stride_h, stride_d, stream);
}
@@ -551,8 +551,8 @@ void nvte_scaled_aligned_causal_masked_softmax_forward(const NVTETensor input,
float scale_factor, cudaStream_t stream) {
NVTE_API_CALL(nvte_scaled_aligned_causal_masked_softmax_forward);
using namespace transformer_engine;
-scaled_aligned_causal_masked_softmax_forward(*reinterpret_cast<const Tensor *>(input),
-                                             reinterpret_cast<Tensor *>(softmax_results),
+scaled_aligned_causal_masked_softmax_forward(*convertNVTETensorCheck(input),
+                                             convertNVTETensorCheck(softmax_results),
                                              scale_factor, stream);
}
@@ -563,6 +563,6 @@ void nvte_scaled_aligned_causal_masked_softmax_backward(const NVTETensor incomin
NVTE_API_CALL(nvte_scaled_aligned_causal_masked_softmax_backward);
using namespace transformer_engine;
scaled_aligned_causal_masked_softmax_backward(
-    *reinterpret_cast<Tensor *>(output_grads), *reinterpret_cast<const Tensor *>(incoming_grads),
-    *reinterpret_cast<const Tensor *>(softmax_results), scale_factor, stream);
+    *convertNVTETensorCheck(output_grads), *convertNVTETensorCheck(incoming_grads),
+    *convertNVTETensorCheck(softmax_results), scale_factor, stream);
}
@@ -815,8 +815,8 @@ void nvte_scaled_softmax_forward(const NVTETensor input, NVTETensor softmax_resu
float scale_factor, cudaStream_t stream) {
NVTE_API_CALL(nvte_scaled_softmax_forward);
using namespace transformer_engine;
-scaled_softmax_forward(*reinterpret_cast<const Tensor *>(input),
-                       reinterpret_cast<Tensor *>(softmax_results), scale_factor, stream);
+scaled_softmax_forward(*convertNVTETensorCheck(input), convertNVTETensorCheck(softmax_results),
+                       scale_factor, stream);
}
void nvte_scaled_softmax_backward(const NVTETensor incoming_grads, const NVTETensor softmax_results,
@@ -824,9 +824,9 @@ void nvte_scaled_softmax_backward(const NVTETensor incoming_grads, const NVTETen
cudaStream_t stream) {
NVTE_API_CALL(nvte_scaled_softmax_backward);
using namespace transformer_engine;
-scaled_softmax_backward(*reinterpret_cast<Tensor *>(output_grads),
-                        *reinterpret_cast<const Tensor *>(incoming_grads),
-                        *reinterpret_cast<const Tensor *>(softmax_results), scale_factor, stream);
+scaled_softmax_backward(*convertNVTETensorCheck(output_grads),
+                        *convertNVTETensorCheck(incoming_grads),
+                        *convertNVTETensorCheck(softmax_results), scale_factor, stream);
}
void nvte_scaled_masked_softmax_forward(const NVTETensor input, const NVTETensor mask,
@@ -834,9 +834,8 @@ void nvte_scaled_masked_softmax_forward(const NVTETensor input, const NVTETensor
cudaStream_t stream) {
NVTE_API_CALL(nvte_scaled_masked_softmax_forward);
using namespace transformer_engine;
-scaled_masked_softmax_forward(*reinterpret_cast<const Tensor *>(input),
-                              *reinterpret_cast<const Tensor *>(mask),
-                              reinterpret_cast<Tensor *>(softmax_results), scale_factor, stream);
+scaled_masked_softmax_forward(*convertNVTETensorCheck(input), *convertNVTETensorCheck(mask),
+                              convertNVTETensorCheck(softmax_results), scale_factor, stream);
}
void nvte_scaled_masked_softmax_backward(const NVTETensor incoming_grads,
@@ -844,7 +843,7 @@ void nvte_scaled_masked_softmax_backward(const NVTETensor incoming_grads,
float scale_factor, cudaStream_t stream) {
NVTE_API_CALL(nvte_scaled_masked_softmax_backward);
using namespace transformer_engine;
-scaled_masked_softmax_backward(
-    *reinterpret_cast<Tensor *>(output_grads), *reinterpret_cast<const Tensor *>(incoming_grads),
-    *reinterpret_cast<const Tensor *>(softmax_results), scale_factor, stream);
+scaled_masked_softmax_backward(*convertNVTETensorCheck(output_grads),
+                               *convertNVTETensorCheck(incoming_grads),
+                               *convertNVTETensorCheck(softmax_results), scale_factor, stream);
}
@@ -599,9 +599,9 @@ void nvte_scaled_upper_triang_masked_softmax_forward(const NVTETensor input,
NVTETensor softmax_results, float scale_factor,
cudaStream_t stream) {
using namespace transformer_engine;
-scaled_upper_triang_masked_softmax_forward(*reinterpret_cast<const Tensor *>(input),
-                                           reinterpret_cast<Tensor *>(softmax_results),
-                                           scale_factor, stream);
+scaled_upper_triang_masked_softmax_forward(*convertNVTETensorCheck(input),
+                                           convertNVTETensorCheck(softmax_results), scale_factor,
+                                           stream);
}
void nvte_scaled_upper_triang_masked_softmax_backward(const NVTETensor incoming_grads,
@@ -610,6 +610,6 @@ void nvte_scaled_upper_triang_masked_softmax_backward(const NVTETensor incoming_
cudaStream_t stream) {
using namespace transformer_engine;
scaled_upper_triang_masked_softmax_backward(
-    *reinterpret_cast<Tensor *>(output_grads), *reinterpret_cast<const Tensor *>(incoming_grads),
-    *reinterpret_cast<const Tensor *>(softmax_results), scale_factor, stream);
+    *convertNVTETensorCheck(output_grads), *convertNVTETensorCheck(incoming_grads),
+    *convertNVTETensorCheck(softmax_results), scale_factor, stream);
}
@@ -14,6 +14,7 @@
#include "rocm_gemm.hip"
#endif  // #ifndef __HIP_PLATFORM_AMD__
#include <transformer_engine/gemm.h>
+#include <transformer_engine/multi_stream.h>
#include <transformer_engine/transformer_engine.h>
#include <cstdint>
@@ -22,6 +23,7 @@
#include "../common.h"
#include "../util/handle_manager.h"
#include "../util/logging.h"
+#include "../util/multi_stream.h"
#include "common/util/cuda_runtime.h"
#ifndef __HIP_PLATFORM_AMD__
@@ -94,7 +96,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
A.scaling_mode == B.scaling_mode ||
(A.scaling_mode == NVTE_BLOCK_SCALING_1D && B.scaling_mode == NVTE_BLOCK_SCALING_2D) ||
(A.scaling_mode == NVTE_BLOCK_SCALING_2D && B.scaling_mode == NVTE_BLOCK_SCALING_1D),
-"Inputs A and B to GEMM need to have compatible scaling modes!");
+"Inputs A and B to GEMM need to have compatible scaling modes, but got A.scaling_mode = " +
+    to_string(A.scaling_mode) + ", B.scaling_mode = " + to_string(B.scaling_mode));
NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!");
NVTE_CHECK(B.has_data() || B.has_columnwise_data(), "Input B does not hold any data!");
GemmParam ret;
@@ -507,7 +510,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE,
                                                 &epilogue, sizeof(epilogue)));
-#if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205
+#if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205 && CUDA_VERSION < 13000 && \
+    CUBLAS_VERSION < 130000
if (counter != nullptr) {
if (m_split == 0) m_split = 1;
if (n_split == 0) n_split = 1;
@@ -536,6 +540,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
const auto B_alignment = _getAlignment(reinterpret_cast<uintptr_t>(param.B));
const auto C_alignment = _getAlignment(reinterpret_cast<uintptr_t>(C));
const auto D_alignment = _getAlignment(reinterpret_cast<uintptr_t>(D));
+const auto workspace_alignment = _getAlignment(reinterpret_cast<uintptr_t>(workspace));
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
    preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, &A_alignment, sizeof(A_alignment)));
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
@@ -544,6 +549,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
    preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, &C_alignment, sizeof(C_alignment)));
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
    preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, &D_alignment, sizeof(D_alignment)));
+NVTE_CHECK(workspace_alignment % 256 == 0,
+           "cuBLAS workspace pointer must be aligned to 256 bytes, got ", workspace_alignment);
const auto status =
    cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference,
@@ -582,18 +589,6 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
}
#endif  // __HIP_PLATFORM_AMD__
-static std::once_flag init_flag;
-static cudaStream_t compute_streams[num_streams];
-static cudaEvent_t cublas_event[num_streams];
-// Warning: only call once per device!
-static void init_streams_and_events() {
-  for (int i = 0; i < num_streams; i++) {
-    NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&compute_streams[i], cudaStreamNonBlocking, -1));
-    NVTE_CHECK_CUDA(cudaEventCreate(&cublas_event[i]));
-  }
-}
// Add for batchgemm
static std::once_flag init_flag_batchgemm;
static cudaStream_t compute_streams_batchgemm[num_batchgemm_streams];
@@ -615,12 +610,12 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
int math_sm_count, cudaStream_t stream, bool nvte_use_hipblaslt, bool nvte_use_rocblas, int compute_stream_offset) {
NVTE_API_CALL(nvte_cublas_gemm);
using namespace transformer_engine;
-const Tensor *inputA = reinterpret_cast<const Tensor *>(A);
-const Tensor *inputB = reinterpret_cast<const Tensor *>(B);
-Tensor *outputD = reinterpret_cast<Tensor *>(D);
-const Tensor *biasTensor = reinterpret_cast<const Tensor *>(bias);
-Tensor *outputGelu = reinterpret_cast<Tensor *>(pre_gelu_out);
-Tensor *wspace = reinterpret_cast<Tensor *>(workspace);
+const Tensor *inputA = convertNVTETensorCheck(A);
+const Tensor *inputB = convertNVTETensorCheck(B);
+Tensor *outputD = convertNVTETensor(D);
+const Tensor *biasTensor = convertNVTETensor(bias);
+Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
+Tensor *wspace = convertNVTETensor(workspace);
#ifdef __HIP_PLATFORM_AMD__
const size_t A0 = inputA->flat_first_dim();
@@ -693,18 +688,19 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
#ifndef __HIP_PLATFORM_AMD__
int cudart_version;
NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&cudart_version));
-NVTE_CHECK(cudart_version >= 12020, "Cuda version 12.2 is required for atomic gemm.");
-NVTE_CHECK(cublasLtGetVersion() >= 120205, "Cublas version 12.2.5 is required for atomic gemm.");
+NVTE_CHECK(cudart_version >= 12020 && cudart_version < 13000,
+           "Cuda version >=12.2 and <13.0 is required for atomic gemm.");
+NVTE_CHECK(cublasLtGetVersion() >= 120205 && cublasLtGetVersion() < 130000,
+           "Cublas version >=12.2.5 and <13.0 is required for atomic gemm.");
#endif
using namespace transformer_engine;
-const Tensor *inputA = reinterpret_cast<const Tensor *>(A);
-const Tensor *inputB = reinterpret_cast<const Tensor *>(B);
-Tensor *outputD = reinterpret_cast<Tensor *>(D);
-const Tensor *biasTensor = reinterpret_cast<const Tensor *>(bias);
-Tensor *outputGelu = reinterpret_cast<Tensor *>(pre_gelu_out);
-const Tensor *inputCounter = reinterpret_cast<const Tensor *>(counter);
-Tensor *wspace = reinterpret_cast<Tensor *>(workspace);
+const Tensor *inputA = convertNVTETensorCheck(A);
+const Tensor *inputB = convertNVTETensorCheck(B);
+Tensor *outputD = convertNVTETensor(D);
+const Tensor *biasTensor = convertNVTETensor(bias);
+Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
+const Tensor *inputCounter = convertNVTETensor(counter);
+Tensor *wspace = convertNVTETensor(workspace);
NVTE_CHECK(is_delayed_tensor_scaling(inputA->scaling_mode) &&
               is_delayed_tensor_scaling(inputB->scaling_mode),
@@ -775,14 +771,15 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT
cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_stream_cublas_gemm);
using namespace transformer_engine;
-// Inits streams and events (once, globally)
-std::call_once(init_flag, init_streams_and_events);
+int num_streams = nvte_get_num_compute_streams();
int num_stream_used = std::min(num_streams, num_gemms);
// wait for current stream to finish
-NVTE_CHECK_CUDA(cudaEventRecord(cublas_event[0], stream));
+NVTE_CHECK_CUDA(cudaEventRecord(detail::get_compute_stream_event(0), stream));
for (int s = 0; s < num_stream_used; s++) {
-NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_streams[s], cublas_event[0]));
+NVTE_CHECK_CUDA(
+    cudaStreamWaitEvent(detail::get_compute_stream(s), detail::get_compute_stream_event(0)));
}
const char *NVTE_BLAS_MULSTREAM = std::getenv("NVTE_FORCE_BLAS_MULSTREAM");
const char *NVTE_FORCE_ROCM_GEMM = std::getenv("NVTE_FORCE_ROCM_GEMM");
@@ -798,23 +795,24 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT
for (int i = 0; i < num_gemms; i++) {
nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
                 workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count,
-                 compute_streams[i % num_streams]);
+                 detail::get_compute_stream(i % num_streams));
}
} else {
for (int i = 0; i < num_gemms; i++) {
nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad,
                 workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count,
-                 compute_streams[i % num_streams], 1, 0, i % num_streams);
+                 detail::get_compute_stream(i % num_streams), 1, 0, i % num_streams);
}
}
// record events on compute streams
for (int s = 0; s < num_stream_used; s++) {
-NVTE_CHECK_CUDA(cudaEventRecord(cublas_event[s], compute_streams[s]));
+NVTE_CHECK_CUDA(
+    cudaEventRecord(detail::get_compute_stream_event(s), detail::get_compute_stream(s)));
}
// wait for all compute streams to finish
for (int s = 0; s < num_stream_used; s++) {
-NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, cublas_event[s]));
+NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, detail::get_compute_stream_event(s)));
}
}
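The control flow above is a classic event-based fork/join: record one event on the caller's stream, make every compute stream wait on it, launch the GEMMs round-robin, then record an event per compute stream and make the caller's stream wait on all of them. A stripped-down CUDA sketch of the same pattern (the stream/event pools stand in for `detail::get_compute_stream(_event)`; error checking elided):

#include <algorithm>
#include <cuda_runtime.h>

// Fork work from `main` onto `pool` streams and join back, using events.
void fork_join(cudaStream_t main, cudaStream_t *pool, cudaEvent_t *evt,
               int pool_size, int num_tasks) {
  int used = std::min(pool_size, num_tasks);
  cudaEventRecord(evt[0], main);  // fork point: everything before this is visible
  for (int s = 0; s < used; s++) cudaStreamWaitEvent(pool[s], evt[0], 0);
  for (int i = 0; i < num_tasks; i++) {
    // launch_task(i, pool[i % pool_size]);  // e.g. one cuBLAS GEMM per task
  }
  for (int s = 0; s < used; s++) cudaEventRecord(evt[s], pool[s]);
  for (int s = 0; s < used; s++) cudaStreamWaitEvent(main, evt[s], 0);  // join
}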
@@ -259,6 +259,17 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor act_inp
 */
void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream);
+/*! \brief Casts multiple input tensors to quantized output tensors.
+ *
+ *  \param[in]     inputs        List of input tensors to be cast.
+ *  \param[in,out] outputs       List of output quantized tensors.
+ *  \param[in]     quant_config  (Optional) Quantization configurations.
+ *  \param[in]     num_tensors   Number of tensors in the lists.
+ *  \param[in]     stream        CUDA stream used for the operation.
+ */
+void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs,
+                                const NVTEQuantizationConfig quant_config, const size_t num_tensors,
+                                cudaStream_t stream);
#ifdef __cplusplus
}  // extern "C"
#endif
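A hedged usage sketch of the new batched API: it replaces calling `nvte_quantize` in a loop, letting the library fuse or pipeline the casts. Tensor construction is elided since it depends on the surrounding framework, and the header path is assumed from the context of this hunk:

#include <cuda_runtime.h>
#include <transformer_engine/cast.h>  // assumed header for the quantize APIs
#include <vector>

// Quantize a batch of tensors in one call; ins.size() == outs.size() assumed.
void quantize_all(const std::vector<NVTETensor> &ins, std::vector<NVTETensor> &outs,
                  NVTEQuantizationConfig cfg, cudaStream_t stream) {
  nvte_multi_tensor_quantize(ins.data(), outs.data(), cfg, ins.size(), stream);
}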
@@ -17,23 +17,21 @@
extern "C" {
#endif
-/*! \brief Transposes the input, providing the option to immediately exit the kernel
- *         based on the value of the 'noop' tensor.
+/*! \brief Transposes the input.
 *
- * \param[in]     input   Input tensor.
- * \param[in]     noop    Noop tensor.
+ * \param[in]     input   Input tensor to be cast.
+ * \param[in]     noop    If this single element tensor has non-zero value, kernel will exit immediately.
 * \param[in,out] output  Output tensor.
 * \param[in]     stream  CUDA stream used for the operation.
 */
void nvte_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output,
                              cudaStream_t stream);
-/*! \brief Casts and transposes the input, providing the option to immediately exit the kernel
- *         based on the value of the 'noop' tensor.
+/*! \brief Casts and transposes the input.
 *
- * \param[in]     input   Input tensor.
- * \param[in]     noop    Noop tensor.
- * \param[in,out] output  Output tensor.
+ * \param[in]     input   Input tensor to be cast.
+ * \param[in]     noop    If this single element tensor has non-zero value, kernel will exit immediately.
+ * \param[in,out] output  Output quantized tensor.
 * \param[in]     stream  CUDA stream used for the operation.
 */
void nvte_cast_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output,
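Per the reworded docs, `noop` is a one-element device tensor acting as a runtime kill switch: the kernel reads it and returns immediately when it is non-zero, which allows launches captured in a CUDA graph to be disabled without rebuilding the graph. A hedged sketch of the kernel-side guard pattern (illustrative, not TE's actual kernel):

// Guard pattern assumed by the noop parameter: bail out before doing work.
__global__ void transpose_kernel(const float *in, float *out, int rows, int cols,
                                 const float *noop) {
  if (noop != nullptr && noop[0] != 0.0f) return;  // runtime kill switch
  int r = blockIdx.y * blockDim.y + threadIdx.y;
  int c = blockIdx.x * blockDim.x + threadIdx.x;
  if (r < rows && c < cols) out[c * rows + r] = in[r * cols + c];
}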
@@ -172,6 +172,7 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
/*! \brief Get fused attention backend based on input parameters.
 *
+ * \param[in] is_training Whether the model is in training mode.
 * \param[in] q_dtype The data type of Tensor Q.
 * \param[in] kv_dtype The data type of Tensors K, V.
 * \param[in] qkv_layout The layout of Tensors Q, K, V.
@@ -188,10 +189,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
 * \param[in] window_size_right Sliding window size (the right half).
 */
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
-    NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups,
-    size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v,
-    int64_t window_size_left, int64_t window_size_right);
+    bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads,
+    size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
+    size_t head_dim_v, int64_t window_size_left, int64_t window_size_right);
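A hedged usage sketch of the updated signature, with the new `is_training` flag leading the argument list. The enum values are illustrative picks (a causal FP16 BS3HD case); `-1` window sizes follow the common "no sliding window" convention:

#include <transformer_engine/fused_attn.h>

NVTE_Fused_Attn_Backend pick_backend() {
  return nvte_get_fused_attn_backend(
      /*is_training=*/true, kNVTEFloat16, kNVTEFloat16, NVTE_BS3HD, NVTE_NO_BIAS,
      NVTE_CAUSAL_MASK, /*dropout=*/0.0f, /*num_attn_heads=*/16, /*num_gqa_groups=*/16,
      /*max_seqlen_q=*/2048, /*max_seqlen_kv=*/2048, /*head_dim_qk=*/64,
      /*head_dim_v=*/64, /*window_size_left=*/-1, /*window_size_right=*/-1);
}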
/*! \brief Compute dot product attention with packed QKV input.
 *
@@ -580,6 +581,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
                         cudaStream_t stream);
/*! \brief Update the RNG state with the seed and calculated offset.
+ *
+ * \warning This API is **experimental** and subject to change.
 *
 * \param[in] rng_state_dst RNG state to store seed and offset.
 * \param[in] seed Seed for RNG state.
@@ -595,6 +598,8 @@ void nvte_populate_rng_state_async(NVTETensor rng_state_dst, const NVTETensor se
                                   NVTE_Fused_Attn_Backend backend, cudaStream_t stream);
/*! \brief Get KV format for a given QKV layout.
+ *
+ * \warning This API is **experimental** and subject to change.
 *
 * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1].
 * \param[in] workspace Workspace tensor.
@@ -604,48 +609,187 @@ void nvte_populate_rng_state_async(NVTETensor rng_state_dst, const NVTETensor se
uint32_t nvte_get_runtime_num_segments(NVTETensor cu_seqlen, NVTETensor workspace, size_t len,
                                       cudaStream_t stream);
+/*! \brief Set the seed and offset for RNG state.
+ *
+ * \warning This API is **experimental** and subject to change.
+ *
+ * \param[out] rng_state_ptr A size 2 array storing the RNG's seed and offset respectively.
+ * \param[in] captured Whether a CUDA graph is being captured.
+ * \param[in] seed_ptr Seed pointer.
+ * \param[in] seed_val Seed value.
+ * \param[in] offset_ptr Offset pointer.
+ * \param[in] offset_val Offset value.
+ * \param[in] offset_intragraph Intragraph offset in RNG states. For use with CUDA Graphs.
+ * \param[in] stream CUDA stream used for this operation.
+ */
void nvte_extract_seed_and_offset(int64_t *rng_state_ptr, int captured, int64_t *seed_ptr,
                                  uint64_t seed_val, int64_t *offset_ptr, uint64_t offset_val,
                                  uint32_t offset_intragraph, cudaStream_t stream);
+/*! \brief Copy keys and values into the KV cache.
+ *
+ * \warning This API is **experimental** and subject to change.
+ *
+ * \param[in] new_k Key tensor.
+ * \param[in] new_v Value tensor.
+ * \param[out] k_cache Key cache.
+ * \param[out] v_cache Value cache.
+ * \param[in] page_table Page table for K cache, [batch_size, max_pages_per_seq].
+ * \param[in] cu_new_lens Cumulative sequence lengths.
+ * \param[in] cu_cached_lens Cached cumulative sequence lengths.
+ * \param[in] qkv_format QKV format, e.g. sbhd.
+ * \param[in] b Batch size.
+ * \param[in] max_ctx_len Maximum context length.
+ * \param[in] max_seq_len Maximum sequence length.
+ * \param[in] max_pages_per_seq Maximum number of pages per sequence.
+ * \param[in] is_non_paged Whether the cache is non-paged.
+ * \param[in] stream CUDA stream used for this operation.
+ */
void nvte_copy_to_kv_cache(NVTETensor new_k, NVTETensor new_v, NVTETensor k_cache,
                           NVTETensor v_cache, NVTETensor page_table, NVTETensor cu_new_lens,
                           NVTETensor cu_cached_lens, NVTE_QKV_Format qkv_format, int b,
                           int max_ctx_len, int max_seq_len, int max_pages_per_seq,
                           int is_non_paged, cudaStream_t stream);
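To make the length bookkeeping concrete, here is a CPU reference of one plausible reading of the non-paged case: each sequence's new tokens are appended after its already-cached tokens, with `cu_cached_lens` assumed to include the new tokens. This is a sketch, not the kernel's exact semantics (the real kernel runs on device memory and supports paging via `page_table`):

#include <cstddef>
#include <cstring>

// hidden = h * d bytes per token; cu_* arrays have batch + 1 entries.
void copy_to_cache_ref(const char *new_k, char *k_cache, const int *cu_new_lens,
                       const int *cu_cached_lens, int batch, int max_seq_len,
                       size_t hidden) {
  for (int b = 0; b < batch; b++) {
    int new_len = cu_new_lens[b + 1] - cu_new_lens[b];
    int cached_len = cu_cached_lens[b + 1] - cu_cached_lens[b];  // includes new tokens
    const char *src = new_k + size_t(cu_new_lens[b]) * hidden;   // THD-packed input
    char *dst = k_cache + (size_t(b) * max_seq_len + cached_len - new_len) * hidden;
    std::memcpy(dst, src, size_t(new_len) * hidden);
  }
}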
+/*! \brief Extract the first half (half_idx=0) or second half (half_idx=1) of a THD tensor.
+ *
+ * \warning This API is **experimental** and subject to change.
+ *
+ * \param[in] tensor Input tensor.
+ * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1].
+ * \param[out] half Output tensor.
+ * \param[in] half_idx Whether to read first or second half of input tensor.
+ * \param[in] stream CUDA stream used for this operation.
+ */
void nvte_cp_thd_read_half_tensor(const NVTETensor &tensor, const NVTETensor &cu_seqlens,
                                  NVTETensor half, int half_idx, cudaStream_t stream);
+/*! \brief Correct the second half of the softmax LSE (LogSumExp) for context parallelism.
+ *
+ * \warning This API is **experimental** and subject to change.
+ *
+ * \param[out] lse Output tensor.
+ * \param[in] lse_per_step Input tensor.
+ * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1].
+ * \param[in] lse_packed Whether or not lse_per_step is packed.
+ * \param[in] stream CUDA stream used for this operation.
+ */
void nvte_cp_thd_second_half_lse_correction(NVTETensor lse, const NVTETensor &lse_per_step,
                                            const NVTETensor &cu_seqlens, int lse_packed,
                                            cudaStream_t stream);
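The correction being documented here is the standard log-sum-exp merge used when partial attention results arrive step by step; per token of the second half of each sequence, the stored statistic is updated as

\mathrm{lse} \;\leftarrow\; \log\!\left(e^{\mathrm{lse}} + e^{\mathrm{lse\_per\_step}}\right)

This is a sketch of the intended math, not a quote from the kernel source; robust implementations factor out the max of the two terms before exponentiating to avoid overflow.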
+/*! \brief Read the second half of the softmax LSE (LogSumExp) for context parallelism.
+ *
+ * \warning This API is **experimental** and subject to change.
+ *
+ * \param[in] lse Input tensor.
+ * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1].
+ * \param[out] half_lse Output tensor.
+ * \param[in] lse_packed Whether or not the softmax LSE is in packed format.
+ * \param[in] second_half_lse_seqlen Sequence length.
+ * \param[in] stream CUDA stream used for this operation.
+ */
void nvte_cp_thd_read_second_half_lse(const NVTETensor &lse, const NVTETensor &cu_seqlens,
                                      NVTETensor half_lse, int lse_packed,
                                      int second_half_lse_seqlen, cudaStream_t stream);
/*! \brief Correct the THD format output of context parallelism in forward pass.
*
* \warning This API is **experimental** and subject to change.
*
* \param[out] out Output tensor.
* \param[in] out_per_step THD format output of context parallelism in forward pass.
* \param[in] lse Softmax LSE.
* \param[in] lse_per_step Softmax LSE per step.
* \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1].
* \param[in] only_second_half Whether or not to correct only second half.
* \param[in] lse_packed Whether or the softmax LSE is in packed format.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_cp_thd_out_correction(NVTETensor out, const NVTETensor &out_per_step,
                                const NVTETensor &lse, const NVTETensor &lse_per_step,
                                const NVTETensor &cu_seqlens, int only_second_half, int lse_packed,
                                cudaStream_t stream);
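// Illustration (not part of the header): combining partial attention outputs
// requires rescaling each per-step output by exp(lse_per_step - lse) before
// accumulation. A per-token scalar sketch (names are illustrative):
#include <cmath>
#include <vector>

void accumulate_out(std::vector<float> &out, const std::vector<float> &out_per_step,
                    float lse, float lse_per_step) {
  const float coeff = std::exp(lse_per_step - lse);  // softmax renormalization factor
  for (size_t i = 0; i < out.size(); ++i) out[i] += coeff * out_per_step[i];
}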
/*! \brief Correct the THD format gradient of context parallelism in backward pass.
*
* \warning This API is **experimental** and subject to change.
*
* \param[out] grad Output tensor.
* \param[in] grad_per_step THD format gradient of context parallelism.
* \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1].
* \param[in] first_half One of ("add", "copy", "none") correction op for first half.
* \param[in] second_half One of ("add", "copy", "none") correction op for second half.
 *                        Must be different from first_half.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_cp_thd_grad_correction(NVTETensor grad, const NVTETensor &grad_per_step,
                                 const NVTETensor &cu_seqlens, const char *first_half,
                                 const char *second_half, cudaStream_t stream);
/*! \brief Generate partitioned indices for inputs in THD format.
*
* \warning This API is **experimental** and subject to change.
*
* \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1].
* \param[out] output Output tensor.
* \param[in] total_tokens Total number of tokens.
* \param[in] world_size Total number of devices for context parallelism.
* \param[in] rank Device ID for current device.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_cp_thd_get_partitioned_indices(const NVTETensor &cu_seqlens, NVTETensor output,
                                         int total_tokens, int world_size, int rank,
                                         cudaStream_t stream);
/*! \brief Convert tensor from THD to BSHD format.
*
* \warning This API is **experimental** and subject to change.
*
* \param[in] tensor Input tensor.
* \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1].
* \param[out] new_tensor Output tensor.
* \param[in] b Batch size.
* \param[in] max_seq_len Maximum sequence length.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_convert_thd_to_bshd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor,
                              int b, int max_seq_len, cudaStream_t stream);
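// Illustration (not part of the header): a host-side sketch of the layout
// change, treating each token as a contiguous vector of h * d floats and
// zero-padding each sequence up to max_seq_len (an assumption about the
// padding semantics, not the kernel itself):
#include <cstring>

void thd_to_bshd_ref(const float *thd, float *bshd, const int *cu_seqlens,
                     int b, int max_seq_len, int hd) {
  for (int i = 0; i < b; ++i) {
    const int begin = cu_seqlens[i];
    const int len = cu_seqlens[i + 1] - begin;
    float *dst = bshd + static_cast<size_t>(i) * max_seq_len * hd;
    std::memcpy(dst, thd + static_cast<size_t>(begin) * hd,
                sizeof(float) * static_cast<size_t>(len) * hd);
    std::memset(dst + static_cast<size_t>(len) * hd, 0,
                sizeof(float) * static_cast<size_t>(max_seq_len - len) * hd);
  }
}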
/*! \brief Convert tensor from BSHD to THD format.
*
* \warning This API is **experimental** and subject to change.
*
* \param[in] tensor Input tensor.
* \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1].
* \param[out] new_tensor Output tensor.
 * \param[in] t Total number of tokens.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_convert_bshd_to_thd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor,
                              int t, cudaStream_t stream);
/*! \brief Prepare QKV tensor for Flash Attention forward kernel.
*
* \warning This API is **experimental** and subject to change.
*
* \param[in] qkvi Input tensor.
* \param[out] qkv Output tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_prepare_flash_attn_fwd(NVTETensor qkvi, NVTETensor qkv, cudaStream_t stream);
/*! \brief Prepare QKV tensor for Flash Attention backward kernel.
*
* \warning This API is **experimental** and subject to change.
*
* \param[in] q Input query tensor.
* \param[in] k Input key tensor.
* \param[in] v Input value tensor.
* \param[out] qkv Output tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_prepare_flash_attn_bwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTETensor qkv,
                                 cudaStream_t stream);
...
...@@ -132,12 +132,8 @@ void nvte_cublas_batchgemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
 */
namespace transformer_engine {
#ifdef __HIP_PLATFORM_AMD__
// On DCU, two streams perform better
constexpr int num_streams = 2;
// Add for batchgemm stream
constexpr int num_batchgemm_streams = 1;
#else
constexpr int num_streams = 4;
#endif
/*! \brief TE/JAX cudaGraph requires the cuBLAS initialization to happen outside of the capturing
...
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file multi_stream.h
* \brief Functions for multi streams executions.
*/
#ifndef TRANSFORMER_ENGINE_MULTI_STREAM_H
#define TRANSFORMER_ENGINE_MULTI_STREAM_H
#ifdef __cplusplus
extern "C" {
#endif
/*! \brief Number of CUDA streams to use in multi-stream operations */
int nvte_get_num_compute_streams();
#ifdef __cplusplus
} // extern "C"
#endif
#endif // TRANSFORMER_ENGINE_MULTI_STREAM_H
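// Illustration (not part of the header): one plausible use of the query is
// round-robin sharding of independent work across the library's compute
// streams. The include path and loop body are assumptions, not prescribed usage:
#include <transformer_engine/multi_stream.h>

void shard_work(int num_chunks) {
  const int n = nvte_get_num_compute_streams();
  for (int i = 0; i < num_chunks; ++i) {
    const int stream_id = i % n;  // round-robin over the available streams
    // ... enqueue chunk i on compute stream `stream_id` ...
    (void)stream_id;
  }
}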
...@@ -17,6 +17,25 @@ ...@@ -17,6 +17,25 @@
extern "C" { extern "C" {
#endif #endif
/*! \brief Computes L2 norm for a list of tensors.
*
* \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
 * \param[in] noop_flag If this single-element tensor holds a non-zero value, the kernel exits immediately.
* \param[in] tensor_lists 2D array of input tensors.
* \param[in] num_tensor_lists Size (dim0) of tensor_lists.
* \param[in] num_tensors_per_list Size (dim1) of tensor_lists.
* \param[in] output Scratch space. Required size grows with number of inputs.
 * \param[in] output_per_tensor Fixed-size auxiliary scratch space.
* \param[out] ret L2 norm of all inputs.
* \param[out] ret_per_tensor L2 norm for each tensor.
* \param[in] per_tensor Whether to calculate per tensor or cumulative norm.
* \param[in] max_chunks_per_tensor Maximum number of chunks in any input tensor.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
                                   const size_t num_tensor_lists, const size_t num_tensors_per_list,
                                   NVTETensor output, NVTETensor output_per_tensor, NVTETensor ret,
...@@ -24,6 +43,28 @@ void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETen
                                   int max_chunks_per_tensor, const int device_id,
                                   cudaStream_t stream);
/*! \brief Computes L2 norm for a list of tensors after unscaling.
*
* Unscaling is only done for computing the L2 norm. The tensors themselves are not updated.
*
* \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
 * \param[in] noop_flag If this single-element tensor holds a non-zero value, the kernel exits immediately.
* \param[in] tensor_lists 2D array of input tensors.
* \param[in] num_tensor_lists Size (dim0) of tensor_lists.
* \param[in] num_tensors_per_list Size (dim1) of tensor_lists.
* \param[in] output Scratch space. Required size grows with number of inputs.
 * \param[in] output_per_tensor Fixed-size auxiliary scratch space.
* \param[out] ret L2 norm of all inputs.
* \param[out] ret_per_tensor L2 norm for each tensor.
* \param[in] inv_scale Scalar for the unscaling operation.
* \param[in] per_tensor Whether to calculate per tensor or cumulative norm.
* \param[in] max_chunks_per_tensor Maximum number of chunks in any input tensor.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
                                           NVTETensor **tensor_lists, const size_t num_tensor_lists,
                                           const size_t num_tensors_per_list, NVTETensor output,
...@@ -32,6 +73,27 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
                                           int per_tensor, int max_chunks_per_tensor,
                                           const int device_id, cudaStream_t stream);
/*! \brief Compute and apply gradient update to parameters for Adam optimizer.
*
* \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
 * \param[in] noop_flag If this single-element tensor holds a non-zero value, the kernel exits immediately.
* \param[in,out] tensor_lists 2D array of input tensors.
* \param[in] num_tensor_lists Size (dim0) of tensor_lists.
* \param[in] num_tensors_per_list Size (dim1) of tensor_lists.
* \param[in] lr Learning rate.
* \param[in] beta1 Coefficient for first moment of gradient.
* \param[in] beta2 Coefficient for second moment of gradient.
* \param[in] epsilon Term added to the denominator for numerical stability.
* \param[in] step Iteration counter.
* \param[in] mode Whether to use AdamW (L2 penalty applied to params).
* \param[in] bias_correction Whether to apply correction factor for moment estimates.
* \param[in] weight_decay L2 penalty for weight decay.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
                                 const size_t num_tensor_lists, const size_t num_tensors_per_list,
                                 const float lr, const float beta1, const float beta2,
...@@ -39,12 +101,57 @@ void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETenso
                                 const int bias_correction, const float weight_decay,
                                 const int device_id, cudaStream_t stream);
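// Illustration (not part of the header): a scalar reference for the update the
// fused kernel applies per element (standard Adam with optional bias
// correction; the mode/weight-decay variants follow the flags documented above):
#include <cmath>

void adam_ref(float &p, float &m, float &v, float g, float lr, float beta1,
              float beta2, float epsilon, int step, bool bias_correction) {
  m = beta1 * m + (1.0f - beta1) * g;      // first moment estimate
  v = beta2 * v + (1.0f - beta2) * g * g;  // second moment estimate
  const float m_hat = bias_correction ? m / (1.0f - std::pow(beta1, step)) : m;
  const float v_hat = bias_correction ? v / (1.0f - std::pow(beta2, step)) : v;
  p -= lr * m_hat / (std::sqrt(v_hat) + epsilon);
}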
/*! \brief Compute and apply gradient update to parameters for Adam optimizer
* where the master parameters only store the remainder bits.
*
* \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
 * \param[in] noop_flag If this single-element tensor holds a non-zero value, the kernel exits immediately.
* \param[in,out] tensor_lists 2D array of input tensors.
* \param[in] num_tensor_lists Size (dim0) of tensor_lists.
* \param[in] num_tensors_per_list Size (dim1) of tensor_lists.
* \param[in] lr Learning rate.
* \param[in] beta1 Coefficient for first moment of gradient.
* \param[in] beta2 Coefficient for second moment of gradient.
* \param[in] epsilon Term added to the denominator for numerical stability.
* \param[in] step Iteration counter.
* \param[in] mode Whether to use AdamW (L2 penalty applied to params).
* \param[in] bias_correction Whether to apply correction factor for moment estimates.
* \param[in] weight_decay L2 penalty for weight decay.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_adam_param_remainder_cuda(
    int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
    const size_t num_tensors_per_list, const float lr, const float beta1, const float beta2,
    const float epsilon, const int step, const int mode, const int bias_correction,
    const float weight_decay, const int device_id, cudaStream_t stream);
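// Illustration (not part of the header): the "remainder" scheme is assumed to
// keep the high 16 bits of each FP32 master weight in the BF16 parameter and
// the low 16 bits in a separate buffer, so the full master value can be rebuilt:
#include <cstdint>
#include <cstring>

float rebuild_master(uint16_t bf16_bits, uint16_t remainder_bits) {
  const uint32_t u = (static_cast<uint32_t>(bf16_bits) << 16) | remainder_bits;
  float f;
  std::memcpy(&f, &u, sizeof(f));  // bit-exact reinterpretation of the 32 bits
  return f;
}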
/*! \brief Compute and apply gradient update to parameters for Adam optimizer
* when model parameters are in Float8 precision.
*
* \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
 * \param[in] noop_flag If this single-element tensor holds a non-zero value, the kernel exits immediately.
* \param[in,out] tensor_lists 2D array of input tensors.
* \param[in] num_tensor_lists Size (dim0) of tensor_lists.
* \param[in] num_tensors_per_list Size (dim1) of tensor_lists.
* \param[in] lr Learning rate.
* \param[in] beta1 Coefficient for first moment of gradient.
* \param[in] beta2 Coefficient for second moment of gradient.
* \param[in] epsilon Term added to the denominator for numerical stability.
* \param[in] step Iteration counter.
* \param[in] mode Whether to use AdamW (L2 penalty applied to params).
* \param[in] bias_correction Whether to apply correction factor for moment estimates.
* \param[in] weight_decay L2 penalty for weight decay.
* \param[in] fp8_dtype FP8 data type for model parameters.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
                                     NVTETensor **tensor_lists, const size_t num_tensor_lists,
                                     const size_t num_tensors_per_list, const float lr,
...@@ -53,28 +160,125 @@ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
                                     const float weight_decay, const NVTEDType fp8_dtype,
                                     const int device_id, cudaStream_t stream);
/*! \brief Compute and apply gradient update to parameters for Adam optimizer
* with CUDA graph support and LR scheduling.
*
* \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
 * \param[in] noop_flag If this single-element tensor holds a non-zero value, the kernel exits immediately.
* \param[in,out] tensor_lists 2D array of input tensors.
* \param[in] num_tensor_lists Size (dim0) of tensor_lists.
* \param[in] num_tensors_per_list Size (dim1) of tensor_lists.
* \param[in] lr Learning rate.
* \param[in] beta1 Coefficient for first moment of gradient.
* \param[in] beta2 Coefficient for second moment of gradient.
* \param[in] epsilon Term added to the denominator for numerical stability.
* \param[in] step Iteration counter.
* \param[in] mode Whether to use AdamW (L2 penalty applied to params).
* \param[in] bias_correction Whether to apply correction factor for moment estimates.
* \param[in] weight_decay L2 penalty for weight decay.
* \param[in] inv_scale Scalar for the unscaling operation.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_adam_capturable_cuda(
    int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
    const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2,
    const float epsilon, NVTETensor step, const int mode, const int bias_correction,
    const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream);
/*! \brief Compute and apply gradient update to parameters for Adam optimizer
* with CUDA graph support, LR scheduling, and FP32 master weights.
*
* \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
 * \param[in] noop_flag If this single-element tensor holds a non-zero value, the kernel exits immediately.
* \param[in,out] tensor_lists 2D array of input tensors.
* \param[in] num_tensor_lists Size (dim0) of tensor_lists.
* \param[in] num_tensors_per_list Size (dim1) of tensor_lists.
* \param[in] lr Learning rate.
* \param[in] beta1 Coefficient for first moment of gradient.
* \param[in] beta2 Coefficient for second moment of gradient.
* \param[in] epsilon Term added to the denominator for numerical stability.
* \param[in] step Iteration counter.
* \param[in] mode Whether to use AdamW (L2 penalty applied to params).
* \param[in] bias_correction Whether to apply correction factor for moment estimates.
* \param[in] weight_decay L2 penalty for weight decay.
* \param[in] inv_scale Scalar for the unscaling operation.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_adam_capturable_master_cuda(
    int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
    const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2,
    const float epsilon, NVTETensor step, const int mode, const int bias_correction,
    const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream);
/*! \brief Compute and apply gradient update to parameters for SGD optimizer.
*
* \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
 * \param[in] noop_flag If this single-element tensor holds a non-zero value, the kernel exits immediately.
* \param[in,out] tensor_lists 2D array of input tensors.
* \param[in] num_tensor_lists Size (dim0) of tensor_lists.
* \param[in] num_tensors_per_list Size (dim1) of tensor_lists.
* \param[in] wd Weight decay (L2 penalty).
* \param[in] momentum Momentum factor.
* \param[in] dampening Dampening factor.
* \param[in] lr Learning rate.
 * \param[in] nesterov Whether or not to enable Nesterov momentum.
 * \param[in] first_run Whether momentum buffers have been initialized.
 * \param[in] wd_after_momentum Whether to apply weight decay after the momentum update.
* \param[in] scale Scalar for the scaling operation.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
                                const size_t num_tensor_lists, const size_t num_tensors_per_list,
                                float wd, float momentum, float dampening, float lr, int nesterov,
                                int first_run, int wd_after_momentum, float scale,
                                const int device_id, cudaStream_t stream);
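// Illustration (not part of the header): a scalar reference for the plain
// momentum path (the Nesterov and wd-after-momentum variants follow the flags
// documented above):
void sgd_ref(float &p, float &momentum_buf, float g, float wd, float momentum,
             float dampening, float lr, bool first_run) {
  g += wd * p;  // L2 penalty folded into the gradient
  momentum_buf = first_run ? g : momentum * momentum_buf + (1.0f - dampening) * g;
  p -= lr * momentum_buf;
}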
/*! \brief Check overflow and scale a list of tensors.
*
* \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
 * \param[in] noop_flag If this single-element tensor holds a non-zero value, the kernel exits immediately.
* \param[in,out] tensor_lists 2D array of input tensors.
* \param[in] num_tensor_lists Size (dim0) of tensor_lists.
* \param[in] num_tensors_per_list Size (dim1) of tensor_lists.
* \param[in] scale Scalar for the scaling operation.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
                                  const size_t num_tensor_lists, const size_t num_tensors_per_list,
                                  float scale, const int device_id, cudaStream_t stream);
/*! \brief Compute scales and scale-inverses for a list of FP8 tensors.
*
* \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
 * \param[in] noop_flag If this single-element tensor holds a non-zero value, the kernel exits immediately.
* \param[in,out] tensor_lists 2D array of input tensors.
* \param[in] num_tensor_lists Size (dim0) of tensor_lists.
* \param[in] num_tensors_per_list Size (dim1) of tensor_lists.
 * \param[in] max_fp8 Maximum representable value of the underlying FP8 format.
* \param[in] force_pow_2_scales Ensure scaling factors are a power of 2.
* \param[in] epsilon Term added to the denominator for numerical stability.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(
    int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
    const size_t num_tensors_per_list, float max_fp8, int force_pow_2_scales, float epsilon,
...
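// Illustration (not part of the header): a plausible per-tensor recipe for
// this kernel, inferred from the documented parameters (an assumption, not
// the verified implementation):
#include <cmath>

void compute_scale_ref(float amax, float max_fp8, bool force_pow_2_scales,
                       float epsilon, float &scale, float &scale_inv) {
  scale = max_fp8 / (amax + epsilon);
  if (force_pow_2_scales) scale = std::exp2(std::floor(std::log2(scale)));  // round down to a power of 2
  scale_inv = 1.0f / scale;
}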
...@@ -22,17 +22,18 @@ extern "C" {
 * \brief TE datatype.
 */
enum NVTEDType {
  kNVTEByte = 0,        /*!< Byte */
  kNVTEInt16 = 1,       /*!< 16-bit integer */
  kNVTEInt32 = 2,       /*!< 32-bit integer */
  kNVTEInt64 = 3,       /*!< 64-bit integer */
  kNVTEFloat32 = 4,     /*!< 32-bit float */
  kNVTEFloat16 = 5,     /*!< 16-bit float (E5M10) */
  kNVTEBFloat16 = 6,    /*!< 16-bit bfloat (E8M7) */
  kNVTEFloat8E4M3 = 7,  /*!< 8-bit float (E4M3) */
  kNVTEFloat8E5M2 = 8,  /*!< 8-bit float (E5M2) */
  kNVTEFloat8E8M0 = 9,  /*!< 8-bit float (E8M0) */
  kNVTEFloat4E2M1 = 10, /*!< 4-bit float (E2M1) */
  kNVTENumTypes /*!< Number of supported types */
};
/*! \struct NVTEShape
...@@ -87,6 +88,10 @@ enum NVTEScalingMode {
 */
  NVTE_BLOCK_SCALING_1D = 2,
  NVTE_BLOCK_SCALING_2D = 3,
/*! Single NVFP4 scale per block of 16 contiguous elements in forward pass (FWD),
and single MXFP8 scale per block of 32 contiguous elements in backward pass (BWD).
*/
NVTE_FWD_NVFP4_BWD_MXFP8_SCALING = 4,
  NVTE_INVALID_SCALING = 100
};
...@@ -177,6 +182,14 @@ size_t nvte_tensor_ndims(const NVTETensor tensor);
 */
size_t nvte_tensor_size(const NVTETensor tensor, const size_t dim);
/*! \brief Get the byte size for the tensor.
*
* \param[in] tensor Tensor.
*
* \return Byte size of the tensor.
*/
size_t nvte_tensor_size_bytes(const NVTETensor tensor);
/*! \brief Get a tensor's total number of elements.
 *
 * \param[in] tensor Tensor.
...@@ -193,6 +206,14 @@ size_t nvte_tensor_numel(const NVTETensor tensor);
 */
size_t nvte_tensor_element_size(const NVTETensor tensor);
/*! \brief Get the bit size for the tensor's data type.
*
* \param[in] tensor Tensor.
*
* \return Bit size of the tensor's data type.
*/
size_t nvte_tensor_element_size_bits(const NVTETensor tensor);
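// Illustration (not part of the header): with sub-byte types such as
// kNVTEFloat4E2M1, a whole-tensor byte count must be derived from bits so the
// per-element size does not truncate to zero. A minimal sketch of the conversion:
#include <cstddef>

size_t tensor_bytes_from_bits(size_t numel, size_t element_size_bits) {
  return (numel * element_size_bits) / 8;  // e.g. 32 FP4 elements -> 16 bytes
}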
/*! \brief Get a tensor's data type.
 *
 * \param[in] tensor Tensor.
...@@ -302,6 +323,13 @@ enum NVTEQuantizationConfigAttribute {
      conditional early even when captured in a static CUDA graph.
   */
  kNVTEQuantizationConfigNoopTensor = 2,
/*! Data format for an FP8 block-scaled tensor
*
* This is not the right design since the tensor format is a
* property of the tensor, not the quantization. This enum will
* likely be refactored away in the future.
*/
kNVTEQuantizationConfigFloat8BlockScaleTensorFormat = 3,
  kNVTEQuantizationConfigNumAttributes
};
...@@ -383,7 +411,8 @@ enum class DType {
  kFloat8E4M3 = 7,
  kFloat8E5M2 = 8,
  kFloat8E8M0 = 9,
  kFloat4E2M1 = 10,
  kInt8 = 11,
  kNumTypes
};
...@@ -392,7 +421,16 @@ enum class DType {
 * Return true if TE datatype is FP8
 * \param[in] t TE datatype of interest
 */
inline bool is_fp8_dtype(const DType t) {
  return t == DType::kFloat8E4M3 || t == DType::kFloat8E5M2;
}
/*! \brief Check if TE datatype is FP4
*
* Return true if TE datatype is FP4
 * \param[in] t TE datatype of interest
*/
inline bool is_fp4_dtype(const DType t) { return t == DType::kFloat4E2M1; }
/*! \struct TensorWrapper
 * \brief C++ wrapper for the NVTETensor class.
...@@ -621,6 +659,15 @@ class TensorWrapper {
    return nvte_tensor_element_size(tensor_);
  }
/*! \brief Get the tensor's element size in bits.
*
* \return Element size in bits.
*/
size_t element_size_bits() const noexcept {
if (tensor_ == nullptr) return 0;
return nvte_tensor_element_size_bits(tensor_);
}
  /*! \brief Get the tensor's allocated size in bytes. This will return 0 for tensors with nullptr
   * data even if the TensorWrapper has a non-zero shape and valid dtype.
   *
...@@ -628,7 +675,7 @@ class TensorWrapper {
   */
  size_t bytes() const noexcept {
    if (tensor_ == nullptr || this->dptr() == nullptr) return 0;
    return nvte_tensor_size_bytes(tensor_);
  }
  /*! \brief Get the data type of this TensorWrapper.
...@@ -722,6 +769,16 @@ class TensorWrapper {
  NVTETensor tensor_ = nullptr;
};
/*! \enum Float8BlockScaleTensorFormat
* \brief Data format for an FP8 block-scaled tensor
*/
enum class Float8BlockScaleTensorFormat {
/*! FP8 data is transposed if needed and scales are swizzled */
GEMM_READY = 0,
/*! FP8 data is untransposed and scales are not swizzled or padded */
COMPACT = 1
};
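// Illustration (not part of the header): selecting the compact layout through
// the QuantizationConfigWrapper declared below. A usage sketch, assuming the
// wrapper is default-constructible as in this header:
//
//   transformer_engine::QuantizationConfigWrapper cfg;
//   cfg.set_float8_block_scale_tensor_format(
//       transformer_engine::Float8BlockScaleTensorFormat::COMPACT);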
/*! \struct QuantizationConfigWrapper
 * \brief C++ wrapper for NVTEQuantizationConfigWrapper.
 */
...@@ -775,6 +832,13 @@ class QuantizationConfigWrapper {
                                           sizeof(NVTETensor));
  }
/*! \brief Set FP8 block-scaled tensor format */
void set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat format) {
nvte_set_quantization_config_attribute(config_,
kNVTEQuantizationConfigFloat8BlockScaleTensorFormat,
&format, sizeof(Float8BlockScaleTensorFormat));
}
 private:
  /*! \brief Wrapped NVTEQuantizationConfig. */
  NVTEQuantizationConfig config_ = nullptr;
...
...@@ -807,7 +807,7 @@ void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETenso
  using namespace transformer_engine;
  multi_tensor_adam::multi_tensor_adam_cuda(
      chunk_size, *convertNVTETensorCheck(noop_flag),
      convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2,
      epsilon, step, mode, bias_correction, weight_decay, device_id, stream);
}
...@@ -821,7 +821,7 @@ void nvte_multi_tensor_adam_param_remainder_cuda(
  using namespace transformer_engine;
  multi_tensor_adam::multi_tensor_adam_param_remainder_cuda(
      chunk_size, *convertNVTETensorCheck(noop_flag),
      convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2,
      epsilon, step, mode, bias_correction, weight_decay, device_id, stream);
}
...@@ -837,7 +837,7 @@ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
  using namespace transformer_engine;
  multi_tensor_adam::multi_tensor_adam_fp8_cuda(
      chunk_size, *convertNVTETensorCheck(noop_flag),
      convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2,
      epsilon, step, mode, bias_correction, weight_decay, static_cast<DType>(fp8_dtype), device_id,
      stream);
...@@ -852,11 +852,10 @@ void nvte_multi_tensor_adam_capturable_cuda(
  using namespace transformer_engine;
  multi_tensor_adam::multi_tensor_adam_capturable_cuda(
      chunk_size, *convertNVTETensorCheck(noop_flag),
      convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
      *convertNVTETensorCheck(lr), beta1, beta2, epsilon, *convertNVTETensorCheck(step), mode,
      bias_correction, weight_decay, *convertNVTETensorCheck(inv_scale), device_id, stream);
}
void nvte_multi_tensor_adam_capturable_master_cuda(
...@@ -868,9 +867,8 @@ void nvte_multi_tensor_adam_capturable_master_cuda(
  using namespace transformer_engine;
  multi_tensor_adam::multi_tensor_adam_capturable_master_cuda(
      chunk_size, *convertNVTETensorCheck(noop_flag),
      convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
      *convertNVTETensorCheck(lr), beta1, beta2, epsilon, *convertNVTETensorCheck(step), mode,
      bias_correction, weight_decay, *convertNVTETensorCheck(inv_scale), device_id, stream);
}