"vscode:/vscode.git/clone" did not exist on "a57d13cc967d04478cebcfa2c3440fb167b21c6d"
Unverified Commit 4292653c authored by Przemyslaw Tredak, committed by GitHub

Avoid memory allocations and deallocations when creating NVTETensor (#1813)



* Changed the Tensor allocation strategy
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fixes
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Disable debug flag
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix the double free error
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fixed pyTorch recipe extension
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Hide TensorAllocator and fix the usage in LayerNorm
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Cleaning
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix permutation
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

---------
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 41909dc8
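The diff below swaps the raw reinterpret_cast conversions for the new convertNVTETensor / convertNVTETensorCheck helpers and gives SimpleTensor and Tensor a clear() method plus a cached nvte_tensor handle. The TensorAllocator named in the commit message is not part of the hunks quoted here, so the following is only a rough, hypothetical sketch of the general strategy it implies (reuse cleared objects from a free list instead of calling new/delete on every tensor create/destroy); all names in it are invented for illustration and are not the PR's actual classes.

#include <cstddef>
#include <memory>
#include <vector>

// Stand-in for the per-tensor state that must be wiped before reuse.
struct PooledTensor {
  void *dptr = nullptr;
  std::vector<size_t> shape;

  void clear() {
    dptr = nullptr;
    shape.resize(0);
  }
};

// Hypothetical pool: objects are handed out from a free list and returned to it,
// so steady-state create/destroy cycles stop touching the heap.
class TensorPool {
 public:
  PooledTensor *acquire() {
    if (free_list_.empty()) {
      storage_.push_back(std::make_unique<PooledTensor>());  // allocate only when the pool grows
      return storage_.back().get();
    }
    PooledTensor *t = free_list_.back();
    free_list_.pop_back();
    return t;  // was cleared when it was released
  }

  void release(PooledTensor *t) {
    t->clear();               // reset contents instead of freeing the memory
    free_list_.push_back(t);  // keep the allocation around for the next acquire()
  }

 private:
  std::vector<std::unique_ptr<PooledTensor>> storage_;  // owns every object ever created
  std::vector<PooledTensor *> free_list_;               // objects ready for reuse
};

int main() {
  TensorPool pool;
  PooledTensor *a = pool.acquire();  // first acquire allocates
  pool.release(a);
  PooledTensor *b = pool.acquire();  // second acquire reuses the same object
  return a == b ? 0 : 1;
}

Under that reading, the clear() methods added to SimpleTensor and Tensor in the header hunk below are what makes such reuse safe: a recycled Tensor comes back in the same state as a freshly constructed one.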
@@ -200,7 +200,7 @@ std::vector<std::vector<Tensor *>> convert_tensor_array(NVTETensor **nvte_tensor
   for (size_t i = 0; i < outer_size; ++i) {
     ret.emplace_back();
     for (size_t j = 0; j < inner_size; ++j) {
-      ret.back().push_back(reinterpret_cast<Tensor *>(nvte_tensors[i][j]));
+      ret.back().push_back(convertNVTETensor(nvte_tensors[i][j]));
     }
   }
   return ret;
...
@@ -89,9 +89,16 @@ struct SimpleTensor {
     }
     return acc;
   }
+
+  void clear() {
+    dptr = nullptr;
+    shape.resize(0);
+    dtype = DType::kFloat32;
+  }
 };
 
 struct Tensor {
+ public:
   SimpleTensor data;
   SimpleTensor columnwise_data;
   SimpleTensor amax;
@@ -99,8 +106,8 @@ struct Tensor {
   SimpleTensor scale_inv;
   SimpleTensor columnwise_scale_inv;
 
- public:
   NVTEScalingMode scaling_mode;
+  NVTETensor nvte_tensor;
 
   Tensor()
       : data(),
@@ -109,7 +116,20 @@ struct Tensor {
         scale(nullptr, {1}, DType::kFloat32),
         scale_inv(nullptr, {1}, DType::kFloat32),
         columnwise_scale_inv(nullptr, {1}, DType::kFloat32),
-        scaling_mode(NVTE_DELAYED_TENSOR_SCALING) {}
+        scaling_mode(NVTE_DELAYED_TENSOR_SCALING),
+        nvte_tensor(0) {}
+
+  void clear() {
+    data.clear();
+    columnwise_data.clear();
+    amax.clear();
+    scale.clear();
+    scale_inv.clear();
+    columnwise_scale_inv.clear();
+    scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
+  }
+
+  explicit operator NVTETensor() const noexcept { return nvte_tensor; }
 
   size_t numel() const {
     size_t acc = 1;
@@ -620,6 +640,8 @@ bool is_supported_by_CC_100();
 std::vector<std::vector<Tensor *>> convert_tensor_array(NVTETensor **nvte_tensors,
                                                         size_t outer_size, size_t inner_size);
+Tensor *convertNVTETensor(const NVTETensor tensor);
+Tensor *convertNVTETensorCheck(const NVTETensor tensor);
 
 }  // namespace transformer_engine
 
 #endif  // TRANSFORMER_ENGINE_COMMON_COMMON_H_
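The header hunk above only declares convertNVTETensor and convertNVTETensorCheck; their definitions are not included in this excerpt. A simplified, hypothetical reading of the split, based on how the call sites below use them (the checked variant wherever a tensor must exist, the unchecked one for workspace and for convert_tensor_array), is sketched here. The Tensor type, the handle typedef, and the error handling are stand-ins; the real helpers would report failures through the library's own macros (e.g. NVTE_ERROR, used elsewhere in this diff).

#include <cstdio>
#include <cstdlib>

struct Tensor;             // stand-in for transformer_engine::Tensor
typedef void *NVTETensor;  // assumption: an opaque pointer handle, as the old reinterpret_cast call sites implied

// Plain handle-to-object conversion; may hand back nullptr for an empty handle.
Tensor *convertNVTETensor(const NVTETensor tensor) {
  return reinterpret_cast<Tensor *>(tensor);
}

// Same conversion, but reject a null handle.
Tensor *convertNVTETensorCheck(const NVTETensor tensor) {
  Tensor *t = reinterpret_cast<Tensor *>(tensor);
  if (t == nullptr) {
    std::fprintf(stderr, "Invalid NVTETensor handle.\n");
    std::abort();
  }
  return t;
}

int main() {
  int dummy = 0;
  NVTETensor handle = &dummy;  // placeholder handle, for demonstration only
  return convertNVTETensor(handle) == convertNVTETensorCheck(handle) ? 0 : 1;
}

Whatever the real implementation does (it may also validate the handle against the new allocation scheme), the call-site pattern in the hunks that follow is consistent: required inputs and outputs go through convertNVTETensorCheck, while optional ones such as workspace go through convertNVTETensor.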
@@ -677,9 +677,9 @@ void nvte_cp_thd_read_half_tensor(const NVTETensor &tensor, const NVTETensor &cu
   NVTE_API_CALL(nvte_thd_read_half_tensor);
   using namespace transformer_engine;
 
-  context_parallel::thd_read_half_tensor(*reinterpret_cast<Tensor *>(tensor),
-                                         *reinterpret_cast<Tensor *>(cu_seqlens),
-                                         *reinterpret_cast<Tensor *>(half), half_idx, stream);
+  context_parallel::thd_read_half_tensor(*convertNVTETensorCheck(tensor),
+                                         *convertNVTETensorCheck(cu_seqlens),
+                                         *convertNVTETensorCheck(half), half_idx, stream);
 }
 
 void nvte_cp_thd_second_half_lse_correction(NVTETensor lse, const NVTETensor &lse_per_step,
@@ -689,8 +689,8 @@ void nvte_cp_thd_second_half_lse_correction(NVTETensor lse, const NVTETensor &ls
   using namespace transformer_engine;
 
   context_parallel::thd_second_half_lse_correction(
-      *reinterpret_cast<Tensor *>(lse), *reinterpret_cast<Tensor *>(lse_per_step),
-      *reinterpret_cast<Tensor *>(cu_seqlens), lse_packed, stream);
+      *convertNVTETensorCheck(lse), *convertNVTETensorCheck(lse_per_step),
+      *convertNVTETensorCheck(cu_seqlens), lse_packed, stream);
 }
 
 void nvte_cp_thd_read_second_half_lse(const NVTETensor &lse, const NVTETensor &cu_seqlens,
@@ -700,8 +700,8 @@ void nvte_cp_thd_read_second_half_lse(const NVTETensor &lse, const NVTETensor &c
   using namespace transformer_engine;
 
   context_parallel::thd_read_second_half_lse(
-      *reinterpret_cast<Tensor *>(lse), *reinterpret_cast<Tensor *>(cu_seqlens),
-      *reinterpret_cast<Tensor *>(half_lse), lse_packed, second_half_lse_seqlen, stream);
+      *convertNVTETensorCheck(lse), *convertNVTETensorCheck(cu_seqlens),
+      *convertNVTETensorCheck(half_lse), lse_packed, second_half_lse_seqlen, stream);
 }
 
 void nvte_cp_thd_out_correction(NVTETensor out, const NVTETensor &out_per_step,
@@ -712,9 +712,9 @@ void nvte_cp_thd_out_correction(NVTETensor out, const NVTETensor &out_per_step,
   using namespace transformer_engine;
 
   context_parallel::thd_out_correction(
-      *reinterpret_cast<Tensor *>(out), *reinterpret_cast<Tensor *>(out_per_step),
-      *reinterpret_cast<Tensor *>(lse), *reinterpret_cast<Tensor *>(lse_per_step),
-      *reinterpret_cast<Tensor *>(cu_seqlens), only_second_half, lse_packed, stream);
+      *convertNVTETensorCheck(out), *convertNVTETensorCheck(out_per_step),
+      *convertNVTETensorCheck(lse), *convertNVTETensorCheck(lse_per_step),
+      *convertNVTETensorCheck(cu_seqlens), only_second_half, lse_packed, stream);
 }
 
 void nvte_cp_thd_grad_correction(NVTETensor grad, const NVTETensor &grad_per_step,
@@ -727,8 +727,8 @@ void nvte_cp_thd_grad_correction(NVTETensor grad, const NVTETensor &grad_per_ste
   std::string second_half_str(second_half);
 
   context_parallel::thd_grad_correction(
-      *reinterpret_cast<Tensor *>(grad), *reinterpret_cast<Tensor *>(grad_per_step),
-      *reinterpret_cast<Tensor *>(cu_seqlens), first_half_str, second_half_str, stream);
+      *convertNVTETensorCheck(grad), *convertNVTETensorCheck(grad_per_step),
+      *convertNVTETensorCheck(cu_seqlens), first_half_str, second_half_str, stream);
 }
 
 void nvte_cp_thd_get_partitioned_indices(const NVTETensor &cu_seqlens, NVTETensor output,
@@ -737,7 +737,7 @@ void nvte_cp_thd_get_partitioned_indices(const NVTETensor &cu_seqlens, NVTETenso
   NVTE_API_CALL(nvte_thd_get_partitioned_indices);
   using namespace transformer_engine;
 
-  context_parallel::thd_get_partitioned_indices(*reinterpret_cast<Tensor *>(cu_seqlens),
-                                                *reinterpret_cast<Tensor *>(output), total_tokens,
+  context_parallel::thd_get_partitioned_indices(*convertNVTETensorCheck(cu_seqlens),
+                                                *convertNVTETensorCheck(output), total_tokens,
                                                 world_size, rank, stream);
 }
@@ -138,8 +138,8 @@ void nvte_prepare_flash_attn_fwd(NVTETensor qkvi, NVTETensor qkv, cudaStream_t s
   NVTE_API_CALL(nvte_prepare_flash_attn_fwd);
   using namespace transformer_engine;
 
-  flash_attention::prepare_flash_attn_fwd(*reinterpret_cast<Tensor *>(qkvi),
-                                          *reinterpret_cast<Tensor *>(qkv), stream);
+  flash_attention::prepare_flash_attn_fwd(*convertNVTETensorCheck(qkvi),
+                                          *convertNVTETensorCheck(qkv), stream);
 }
 
 void nvte_prepare_flash_attn_bwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTETensor qkv,
@@ -147,7 +147,7 @@ void nvte_prepare_flash_attn_bwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTET
   NVTE_API_CALL(nvte_prepare_flash_attn_bwd);
   using namespace transformer_engine;
 
-  flash_attention::prepare_flash_attn_bwd(
-      *reinterpret_cast<Tensor *>(q), *reinterpret_cast<Tensor *>(k),
-      *reinterpret_cast<Tensor *>(v), *reinterpret_cast<Tensor *>(qkv), stream);
+  flash_attention::prepare_flash_attn_bwd(*convertNVTETensorCheck(q), *convertNVTETensorCheck(k),
+                                          *convertNVTETensorCheck(v), *convertNVTETensorCheck(qkv),
+                                          stream);
 }
@@ -392,14 +392,14 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
   NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked);
   using namespace transformer_engine;
 
-  const Tensor *input_cu_seqlens = reinterpret_cast<const Tensor *>(cu_seqlens);
-  const Tensor *input_cu_seqlens_padded = reinterpret_cast<const Tensor *>(cu_seqlens_padded);
-  const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(rng_state);
-  const Tensor *input_QKV = reinterpret_cast<const Tensor *>(QKV);
-  const Tensor *input_Bias = reinterpret_cast<const Tensor *>(Bias);
-  Tensor *input_output_S = reinterpret_cast<Tensor *>(S);
-  Tensor *output_O = reinterpret_cast<Tensor *>(O);
-  Tensor *wkspace = reinterpret_cast<Tensor *>(workspace);
+  const Tensor *input_cu_seqlens = convertNVTETensorCheck(cu_seqlens);
+  const Tensor *input_cu_seqlens_padded = convertNVTETensorCheck(cu_seqlens_padded);
+  const Tensor *input_rng_state = convertNVTETensorCheck(rng_state);
+  const Tensor *input_QKV = convertNVTETensorCheck(QKV);
+  const Tensor *input_Bias = convertNVTETensorCheck(Bias);
+  Tensor *input_output_S = convertNVTETensorCheck(S);
+  Tensor *output_O = convertNVTETensorCheck(O);
+  Tensor *wkspace = convertNVTETensor(workspace);
 
   auto ndim = input_QKV->data.shape.size();
   size_t b = input_cu_seqlens->data.shape[0] - 1;
@@ -472,16 +472,16 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
   NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked);
   using namespace transformer_engine;
 
-  const Tensor *input_cu_seqlens = reinterpret_cast<const Tensor *>(cu_seqlens);
-  const Tensor *input_cu_seqlens_padded = reinterpret_cast<const Tensor *>(cu_seqlens_padded);
-  const Tensor *input_QKV = reinterpret_cast<const Tensor *>(QKV);
-  const Tensor *input_O = reinterpret_cast<const Tensor *>(O);
-  const Tensor *input_dO = reinterpret_cast<const Tensor *>(dO);
-  const Tensor *input_S = reinterpret_cast<const Tensor *>(S);
-  Tensor *input_output_dP = reinterpret_cast<Tensor *>(dP);
-  Tensor *output_dQKV = reinterpret_cast<Tensor *>(dQKV);
-  Tensor *output_dBias = reinterpret_cast<Tensor *>(dBias);
-  Tensor *wkspace = reinterpret_cast<Tensor *>(workspace);
+  const Tensor *input_cu_seqlens = convertNVTETensorCheck(cu_seqlens);
+  const Tensor *input_cu_seqlens_padded = convertNVTETensorCheck(cu_seqlens_padded);
+  const Tensor *input_QKV = convertNVTETensorCheck(QKV);
+  const Tensor *input_O = convertNVTETensorCheck(O);
+  const Tensor *input_dO = convertNVTETensorCheck(dO);
+  const Tensor *input_S = convertNVTETensorCheck(S);
+  Tensor *input_output_dP = convertNVTETensorCheck(dP);
+  Tensor *output_dQKV = convertNVTETensorCheck(dQKV);
+  Tensor *output_dBias = convertNVTETensorCheck(dBias);
+  Tensor *wkspace = convertNVTETensor(workspace);
 
   auto ndim = input_QKV->data.shape.size();
   size_t b = input_cu_seqlens->data.shape[0] - 1;
@@ -510,7 +510,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
-    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
     fused_attn_max_512_bwd_qkvpacked(
         b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_QKV,
         input_dO, output_S, output_dQKV, output_dBias, input_cu_seqlens, wkspace, stream, handle);
@@ -519,13 +519,13 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
 #endif
   } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
 #if (CUDNN_VERSION >= 8900)
-    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
     Tensor *input_Bias, *input_rng_state;
     if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
-      input_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
-      input_Bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
+      input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+      input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
     } else {
-      input_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+      input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
     }
     fused_attn_arbitrary_seqlen_bwd_qkvpacked(
         b, h, max_seqlen, d, t, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
@@ -540,9 +540,9 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
 #endif
   } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
 #if (CUDNN_VERSION >= 8900)
-    const Tensor *input_M = reinterpret_cast<const Tensor *>(Aux_CTX_Tensors->tensors[0]);
-    const Tensor *input_ZInv = reinterpret_cast<const Tensor *>(Aux_CTX_Tensors->tensors[1]);
-    const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(Aux_CTX_Tensors->tensors[2]);
+    const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
+    const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+    const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
     fused_attn_fp8_bwd_qkvpacked(b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type,
                                  attn_mask_type, input_QKV, input_O, input_dO, input_M, input_ZInv,
                                  input_S, input_output_dP, output_dQKV, input_cu_seqlens,
@@ -566,19 +566,19 @@ void nvte_fused_attn_fwd_kvpacked(
     cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked);
   using namespace transformer_engine;
-  const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
-  const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv);
-  const Tensor *input_cu_seqlens_q_padded = reinterpret_cast<const Tensor *>(cu_seqlens_q_padded);
-  const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast<const Tensor *>(cu_seqlens_kv_padded);
-  const Tensor *input_page_table_k = reinterpret_cast<const Tensor *>(page_table_k);
-  const Tensor *input_page_table_v = reinterpret_cast<const Tensor *>(page_table_v);
-  const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(rng_state);
-  const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q);
-  const Tensor *input_KV = reinterpret_cast<const Tensor *>(KV);
-  const Tensor *input_Bias = reinterpret_cast<const Tensor *>(Bias);
-  Tensor *input_output_S = reinterpret_cast<Tensor *>(S);
-  Tensor *output_O = reinterpret_cast<Tensor *>(O);
-  Tensor *wkspace = reinterpret_cast<Tensor *>(workspace);
+  const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
+  const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv);
+  const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded);
+  const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded);
+  const Tensor *input_page_table_k = convertNVTETensorCheck(page_table_k);
+  const Tensor *input_page_table_v = convertNVTETensorCheck(page_table_v);
+  const Tensor *input_rng_state = convertNVTETensorCheck(rng_state);
+  const Tensor *input_Q = convertNVTETensorCheck(Q);
+  const Tensor *input_KV = convertNVTETensorCheck(KV);
+  const Tensor *input_Bias = convertNVTETensorCheck(Bias);
+  Tensor *input_output_S = convertNVTETensorCheck(S);
+  Tensor *output_O = convertNVTETensorCheck(O);
+  Tensor *wkspace = convertNVTETensor(workspace);
 
   size_t b = input_cu_seqlens_q->data.shape[0] - 1;
   auto ndim = input_Q->data.shape.size();
@@ -686,20 +686,20 @@ void nvte_fused_attn_bwd_kvpacked(
     cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked);
   using namespace transformer_engine;
-  const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
-  const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv);
-  const Tensor *input_cu_seqlens_q_padded = reinterpret_cast<const Tensor *>(cu_seqlens_q_padded);
-  const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast<const Tensor *>(cu_seqlens_kv_padded);
-  const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q);
-  const Tensor *input_KV = reinterpret_cast<const Tensor *>(KV);
-  const Tensor *input_O = reinterpret_cast<const Tensor *>(O);
-  const Tensor *input_dO = reinterpret_cast<const Tensor *>(dO);
-  const Tensor *input_S = reinterpret_cast<const Tensor *>(S);
-  Tensor *input_output_dP = reinterpret_cast<Tensor *>(dP);
-  Tensor *output_dQ = reinterpret_cast<Tensor *>(dQ);
-  Tensor *output_dKV = reinterpret_cast<Tensor *>(dKV);
-  Tensor *output_dBias = reinterpret_cast<Tensor *>(dBias);
-  Tensor *wkspace = reinterpret_cast<Tensor *>(workspace);
+  const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
+  const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv);
+  const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded);
+  const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded);
+  const Tensor *input_Q = convertNVTETensorCheck(Q);
+  const Tensor *input_KV = convertNVTETensorCheck(KV);
+  const Tensor *input_O = convertNVTETensorCheck(O);
+  const Tensor *input_dO = convertNVTETensorCheck(dO);
+  const Tensor *input_S = convertNVTETensorCheck(S);
+  Tensor *input_output_dP = convertNVTETensorCheck(dP);
+  Tensor *output_dQ = convertNVTETensorCheck(dQ);
+  Tensor *output_dKV = convertNVTETensorCheck(dKV);
+  Tensor *output_dBias = convertNVTETensorCheck(dBias);
+  Tensor *wkspace = convertNVTETensor(workspace);
 
   size_t b = input_cu_seqlens_q->data.shape[0] - 1;
   auto ndim = input_Q->data.shape.size();
@@ -736,7 +736,7 @@ void nvte_fused_attn_bwd_kvpacked(
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
-    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
     fused_attn_max_512_bwd_kvpacked(
         b, h_q, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type,
         attn_mask_type, input_Q, input_KV, input_dO, output_S, output_dQ, output_dKV, output_dBias,
@@ -746,13 +746,13 @@ void nvte_fused_attn_bwd_kvpacked(
 #endif
   } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
 #if (CUDNN_VERSION >= 8903)
-    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
     Tensor *input_Bias, *input_rng_state;
     if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
-      input_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
-      input_Bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
+      input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+      input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
     } else {
-      input_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+      input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
     }
     fused_attn_arbitrary_seqlen_bwd_kvpacked(
         b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, attn_scale, dropout, qkv_layout,
@@ -768,9 +768,9 @@ void nvte_fused_attn_bwd_kvpacked(
 #endif
   } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
 #if (CUDNN_VERSION >= 8900)
-    const Tensor *input_M = reinterpret_cast<const Tensor *>(Aux_CTX_Tensors->tensors[0]);
-    const Tensor *input_ZInv = reinterpret_cast<const Tensor *>(Aux_CTX_Tensors->tensors[1]);
-    const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(Aux_CTX_Tensors->tensors[2]);
+    const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
+    const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+    const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
     fused_attn_fp8_bwd_kvpacked(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout,
                                 qkv_layout, bias_type, attn_mask_type, input_Q, input_KV, input_O,
                                 input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ,
@@ -797,20 +797,20 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
     cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_fwd);
   using namespace transformer_engine;
-  const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
-  const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv);
-  const Tensor *input_cu_seqlens_q_padded = reinterpret_cast<const Tensor *>(cu_seqlens_q_padded);
-  const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast<const Tensor *>(cu_seqlens_kv_padded);
-  const Tensor *input_page_table_k = reinterpret_cast<const Tensor *>(page_table_k);
-  const Tensor *input_page_table_v = reinterpret_cast<const Tensor *>(page_table_v);
-  const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(rng_state);
-  const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q);
-  const Tensor *input_K = reinterpret_cast<const Tensor *>(K);
-  const Tensor *input_V = reinterpret_cast<const Tensor *>(V);
-  const Tensor *input_Bias = reinterpret_cast<const Tensor *>(Bias);
-  Tensor *input_output_S = reinterpret_cast<Tensor *>(S);
-  Tensor *output_O = reinterpret_cast<Tensor *>(O);
-  Tensor *wkspace = reinterpret_cast<Tensor *>(workspace);
+  const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
+  const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv);
+  const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded);
+  const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded);
+  const Tensor *input_page_table_k = convertNVTETensorCheck(page_table_k);
+  const Tensor *input_page_table_v = convertNVTETensorCheck(page_table_v);
+  const Tensor *input_rng_state = convertNVTETensorCheck(rng_state);
+  const Tensor *input_Q = convertNVTETensorCheck(Q);
+  const Tensor *input_K = convertNVTETensorCheck(K);
+  const Tensor *input_V = convertNVTETensorCheck(V);
+  const Tensor *input_Bias = convertNVTETensorCheck(Bias);
+  Tensor *input_output_S = convertNVTETensorCheck(S);
+  Tensor *output_O = convertNVTETensorCheck(O);
+  Tensor *wkspace = convertNVTETensor(workspace);
 
   auto ndim = input_Q->data.shape.size();
   auto ndim_kv = input_K->data.shape.size();
@@ -914,22 +914,22 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
     cudaStream_t stream) {
   NVTE_API_CALL(nvte_flash_attn_bwd);
   using namespace transformer_engine;
-  const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
-  const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv);
-  const Tensor *input_cu_seqlens_q_padded = reinterpret_cast<const Tensor *>(cu_seqlens_q_padded);
-  const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast<const Tensor *>(cu_seqlens_kv_padded);
-  const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q);
-  const Tensor *input_K = reinterpret_cast<const Tensor *>(K);
-  const Tensor *input_V = reinterpret_cast<const Tensor *>(V);
-  const Tensor *input_O = reinterpret_cast<const Tensor *>(O);
-  const Tensor *input_dO = reinterpret_cast<const Tensor *>(dO);
-  const Tensor *input_S = reinterpret_cast<const Tensor *>(S);
-  Tensor *input_output_dP = reinterpret_cast<Tensor *>(dP);
-  Tensor *output_dQ = reinterpret_cast<Tensor *>(dQ);
-  Tensor *output_dK = reinterpret_cast<Tensor *>(dK);
-  Tensor *output_dV = reinterpret_cast<Tensor *>(dV);
-  Tensor *output_dBias = reinterpret_cast<Tensor *>(dBias);
-  Tensor *wkspace = reinterpret_cast<Tensor *>(workspace);
+  const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
+  const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv);
+  const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded);
+  const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded);
+  const Tensor *input_Q = convertNVTETensorCheck(Q);
+  const Tensor *input_K = convertNVTETensorCheck(K);
+  const Tensor *input_V = convertNVTETensorCheck(V);
+  const Tensor *input_O = convertNVTETensorCheck(O);
+  const Tensor *input_dO = convertNVTETensorCheck(dO);
+  const Tensor *input_S = convertNVTETensorCheck(S);
+  Tensor *input_output_dP = convertNVTETensorCheck(dP);
+  Tensor *output_dQ = convertNVTETensorCheck(dQ);
+  Tensor *output_dK = convertNVTETensorCheck(dK);
+  Tensor *output_dV = convertNVTETensorCheck(dV);
+  Tensor *output_dBias = convertNVTETensorCheck(dBias);
+  Tensor *wkspace = convertNVTETensor(workspace);
 
   auto ndim = input_Q->data.shape.size();
   auto ndim_kv = input_K->data.shape.size();
@@ -959,7 +959,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
-    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
     fused_attn_max_512_bwd(b, h_q, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout,
                            qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V,
                            input_dO, output_S, output_dQ, output_dK, output_dV, output_dBias,
@@ -969,13 +969,13 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
 #endif
   } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
 #if (CUDNN_VERSION >= 8900)
-    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
     Tensor *input_Bias, *input_rng_state;
     if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
-      input_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
-      input_Bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
+      input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+      input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
     } else {
-      input_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+      input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
     }
     fused_attn_arbitrary_seqlen_bwd(
         b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, dropout,
@@ -991,9 +991,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
 #endif
   } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
 #if (CUDNN_VERSION >= 8900)
-    const Tensor *input_M = reinterpret_cast<const Tensor *>(Aux_CTX_Tensors->tensors[0]);
-    const Tensor *input_ZInv = reinterpret_cast<const Tensor *>(Aux_CTX_Tensors->tensors[1]);
-    const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(Aux_CTX_Tensors->tensors[2]);
+    const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
+    const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+    const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
     fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout,
                        qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, input_O,
                        input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ,
...
@@ -990,7 +990,7 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
     const auto cudnn_runtime_version = cudnnGetVersion();
     if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
       Aux_CTX_Tensors->size = 3;
-      Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+      Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
       output_S->data.dptr = nullptr;
       if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
         output_S->data.shape = {max_tokens, num_attn_heads, 1};
@@ -998,17 +998,17 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
         output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1};
       }
       output_S->data.dtype = DType::kFloat32;
-      Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+      Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
       output_rng_state->data.dptr = nullptr;
       output_rng_state->data.shape = {2};
       output_rng_state->data.dtype = DType::kInt64;
-      Tensor *output_bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
+      Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
       output_bias->data.dptr = nullptr;
       output_bias->data.shape = {bias_b, bias_h, max_seqlen, max_seqlen};
       output_bias->data.dtype = QKV_type;
     } else {
       Aux_CTX_Tensors->size = 2;
-      Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+      Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
       output_S->data.dptr = nullptr;
       if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
         output_S->data.shape = {max_tokens, num_attn_heads, 1};
@@ -1016,22 +1016,22 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
         output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1};
       }
       output_S->data.dtype = DType::kFloat32;
-      Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+      Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
       output_rng_state->data.dptr = nullptr;
       output_rng_state->data.shape = {2};
       output_rng_state->data.dtype = DType::kInt64;
     }
   } else if (Aux_CTX_Tensors->size == 2) {
-    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
     devPtrS = output_S->data.dptr;
-    Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+    Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
    output_rng_state->data.dptr = rng_state->data.dptr;
   } else if (Aux_CTX_Tensors->size == 3) {
-    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
     devPtrS = output_S->data.dptr;
-    Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+    Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
     output_rng_state->data.dptr = rng_state->data.dptr;
-    Tensor *output_bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
+    Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
     output_bias->data.dptr = devPtrBias;
   } else {
     NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
@@ -1216,7 +1216,7 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
     const auto cudnn_runtime_version = cudnnGetVersion();
     if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
       Aux_CTX_Tensors->size = 3;
-      Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+      Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
       output_S->data.dptr = nullptr;
       if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
         output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
@@ -1224,17 +1224,17 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
         output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
       }
       output_S->data.dtype = DType::kFloat32;
-      Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+      Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
       output_rng_state->data.dptr = nullptr;
       output_rng_state->data.shape = {2};
       output_rng_state->data.dtype = DType::kInt64;
-      Tensor *output_bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
+      Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
       output_bias->data.dptr = nullptr;
       output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv};
       output_bias->data.dtype = QKV_type;
     } else {
       Aux_CTX_Tensors->size = 2;
-      Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+      Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
       output_S->data.dptr = nullptr;
       if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
         output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
@@ -1242,22 +1242,22 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
         output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
       }
       output_S->data.dtype = DType::kFloat32;
-      Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+      Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
      output_rng_state->data.dptr = nullptr;
       output_rng_state->data.shape = {2};
       output_rng_state->data.dtype = DType::kInt64;
     }
   } else if (Aux_CTX_Tensors->size == 2) {
-    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
     devPtrS = output_S->data.dptr;
-    Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+    Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
     output_rng_state->data.dptr = rng_state->data.dptr;
   } else if (Aux_CTX_Tensors->size == 3) {
-    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
     devPtrS = output_S->data.dptr;
-    Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+    Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
     output_rng_state->data.dptr = rng_state->data.dptr;
-    Tensor *output_bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
+    Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
     output_bias->data.dptr = devPtrBias;
   } else {
     NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
@@ -1446,7 +1446,7 @@ void fused_attn_arbitrary_seqlen_fwd(
     const auto cudnn_runtime_version = cudnnGetVersion();
     if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
       Aux_CTX_Tensors->size = 3;
-      Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+      Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
       output_S->data.dptr = nullptr;
       if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
         output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
@@ -1454,17 +1454,17 @@ void fused_attn_arbitrary_seqlen_fwd(
         output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
       }
       output_S->data.dtype = DType::kFloat32;
-      Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+      Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
       output_rng_state->data.dptr = nullptr;
       output_rng_state->data.shape = {2};
       output_rng_state->data.dtype = DType::kInt64;
-      Tensor *output_bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
+      Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
       output_bias->data.dptr = nullptr;
       output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv};
       output_bias->data.dtype = QKV_type;
     } else {
       Aux_CTX_Tensors->size = 2;
-      Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+      Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
       output_S->data.dptr = nullptr;
       if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
         output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
@@ -1472,22 +1472,22 @@ void fused_attn_arbitrary_seqlen_fwd(
         output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
       }
       output_S->data.dtype = DType::kFloat32;
-      Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+      Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
       output_rng_state->data.dptr = nullptr;
       output_rng_state->data.shape = {2};
       output_rng_state->data.dtype = DType::kInt64;
     }
   } else if (Aux_CTX_Tensors->size == 2) {
-    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
     devPtrS = output_S->data.dptr;
-    Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+    Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
     output_rng_state->data.dptr = rng_state->data.dptr;
   } else if (Aux_CTX_Tensors->size == 3) {
-    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
     devPtrS = output_S->data.dptr;
-    Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+    Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
     output_rng_state->data.dptr = rng_state->data.dptr;
-    Tensor *output_bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
+    Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
     output_bias->data.dptr = devPtrBias;
   } else {
     NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
...
@@ -1239,12 +1239,12 @@ void fused_attn_max_512_fwd_qkvpacked(
   if (Aux_CTX_Tensors->size == 0) {
     Aux_CTX_Tensors->size = 1;
-    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
     output_S->data.dptr = nullptr;
     output_S->data.shape = {batch, num_head, max_seqlen, max_seqlen};
     output_S->data.dtype = input_QKV->data.dtype;
   } else if (Aux_CTX_Tensors->size == 1) {
-    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
     devPtrS = output_S->data.dptr;
   } else {
     NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
@@ -1317,12 +1317,12 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t num_head, size_t q_max
   if (Aux_CTX_Tensors->size == 0) {
     Aux_CTX_Tensors->size = 1;
-    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
     output_S->data.dptr = nullptr;
     output_S->data.shape = {batch, num_head, q_max_seqlen, kv_max_seqlen};
     output_S->data.dtype = q_type;
   } else if (Aux_CTX_Tensors->size == 1) {
-    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
     devPtrS = output_S->data.dptr;
   } else {
     NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
@@ -1386,12 +1386,12 @@ void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen,
   if (Aux_CTX_Tensors->size == 0) {
     Aux_CTX_Tensors->size = 1;
-    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
     output_S->data.dptr = nullptr;
     output_S->data.shape = {batch, num_head, q_max_seqlen, kv_max_seqlen};
     output_S->data.dtype = q_type;
   } else if (Aux_CTX_Tensors->size == 1) {
-    Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
     devPtrS = output_S->data.dptr;
   } else {
     NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
...
@@ -2383,9 +2383,9 @@ void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t ma
void* devPtrZInv = nullptr;
if (Aux_CTX_Tensors->size == 0) {
Aux_CTX_Tensors->size = 3;
- Tensor* output_M = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[0]);
- Tensor* output_ZInv = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
- Tensor* output_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
+ Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
+ Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+ Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
output_M->data.dptr = nullptr;
output_M->data.shape = {batch, num_attn_heads, max_seqlen, 1};
output_M->data.dtype = DType::kFloat32;
@@ -2396,9 +2396,9 @@ void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t ma
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
} else if (Aux_CTX_Tensors->size == 3) {
- Tensor* output_M = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[0]);
- Tensor* output_ZInv = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
- Tensor* output_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
+ Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
+ Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+ Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
devPtrM = output_M->data.dptr;
devPtrZInv = output_ZInv->data.dptr;
output_rng_state->data.dptr = rng_state->data.dptr;
@@ -2582,9 +2582,9 @@ void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num
void* devPtrZInv = nullptr;
if (Aux_CTX_Tensors->size == 0) {
Aux_CTX_Tensors->size = 3;
- Tensor* output_M = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[0]);
- Tensor* output_ZInv = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
- Tensor* output_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
+ Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
+ Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+ Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
output_M->data.dptr = nullptr;
output_M->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
output_M->data.dtype = DType::kFloat32;
@@ -2595,9 +2595,9 @@ void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
} else if (Aux_CTX_Tensors->size == 3) {
- Tensor* output_M = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[0]);
- Tensor* output_ZInv = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
- Tensor* output_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
+ Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
+ Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+ Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
devPtrM = output_M->data.dptr;
devPtrZInv = output_ZInv->data.dptr;
output_rng_state->data.dptr = rng_state->data.dptr;
@@ -2779,9 +2779,9 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou
void* devPtrZInv = nullptr;
if (Aux_CTX_Tensors->size == 0) {
Aux_CTX_Tensors->size = 3;
- Tensor* output_M = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[0]);
- Tensor* output_ZInv = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
- Tensor* output_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
+ Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
+ Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+ Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
output_M->data.dptr = nullptr;
output_M->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
output_M->data.dtype = DType::kFloat32;
@@ -2792,9 +2792,9 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
} else if (Aux_CTX_Tensors->size == 3) {
- Tensor* output_M = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[0]);
- Tensor* output_ZInv = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
- Tensor* output_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
+ Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
+ Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+ Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
devPtrM = output_M->data.dptr;
devPtrZInv = output_ZInv->data.dptr;
output_rng_state->data.dptr = rng_state->data.dptr;
...
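The fused-attention hunks above all follow the same two-pass protocol for `Aux_CTX_Tensors`: a first call with `size == 0` only records how many auxiliary tensors are needed and their shapes/dtypes, and a second call with `size == 3` (or `1` in the earlier hunk) consumes the buffers the caller allocated in between. The toy program below reproduces that query-then-provide flow with hypothetical stand-in types; it is illustrative only and does not use the real NVTE API.

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

// Hypothetical stand-ins, just enough to show the two-pass aux-tensor flow.
struct FakeTensor {
  std::vector<std::size_t> shape;
  void *dptr = nullptr;
};
struct FakeTensorPack {
  int size = 0;
  FakeTensor tensors[3];
};

void fake_fused_attn_fwd(FakeTensorPack *aux) {
  if (aux->size == 0) {
    // Pass 1: describe the auxiliary outputs so the caller can allocate them.
    aux->size = 3;
    aux->tensors[0].shape = {2, 8, 128, 1};  // e.g. softmax stats M
    aux->tensors[1].shape = {2, 8, 128, 1};  // e.g. ZInv
    aux->tensors[2].shape = {2};             // e.g. RNG state
  } else if (aux->size == 3) {
    // Pass 2: the caller has attached real storage; the kernel would use it.
    std::printf("aux tensor 0 buffer: %p\n", aux->tensors[0].dptr);
  }
}

int main() {
  FakeTensorPack aux;
  fake_fused_attn_fwd(&aux);                // query shapes
  std::vector<float> storage(2 * 8 * 128);  // caller allocates a buffer
  aux.tensors[0].dptr = storage.data();
  fake_fused_attn_fwd(&aux);                // run with real buffers
  return 0;
}
```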
@@ -260,12 +260,12 @@ void nvte_copy_to_kv_cache(NVTETensor new_k, NVTETensor new_v, NVTETensor k_cach
NVTE_API_CALL(nvte_copy_to_kv_cache);
using namespace transformer_engine;
- kv_cache::copy_to_kv_cache(
- *reinterpret_cast<Tensor *>(new_k), *reinterpret_cast<Tensor *>(new_v),
- *reinterpret_cast<Tensor *>(k_cache), *reinterpret_cast<Tensor *>(v_cache),
- *reinterpret_cast<Tensor *>(page_table), *reinterpret_cast<Tensor *>(cu_new_lens),
- *reinterpret_cast<Tensor *>(cu_cached_lens), qkv_format, b, max_ctx_len, max_seq_len,
- max_pages_per_seq, is_non_paged, stream);
+ kv_cache::copy_to_kv_cache(*convertNVTETensorCheck(new_k), *convertNVTETensorCheck(new_v),
+ *convertNVTETensorCheck(k_cache), *convertNVTETensorCheck(v_cache),
+ *convertNVTETensorCheck(page_table),
+ *convertNVTETensorCheck(cu_new_lens),
+ *convertNVTETensorCheck(cu_cached_lens), qkv_format, b, max_ctx_len,
+ max_seq_len, max_pages_per_seq, is_non_paged, stream);
}
/***************************************************************************************************
@@ -277,9 +277,9 @@ void nvte_convert_thd_to_bshd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETens
NVTE_API_CALL(nvte_convert_thd_to_bshd);
using namespace transformer_engine;
- kv_cache::convert_thd_to_bshd(*reinterpret_cast<Tensor *>(tensor),
- *reinterpret_cast<Tensor *>(cu_seqlens),
- *reinterpret_cast<Tensor *>(new_tensor), b, max_seq_len, stream);
+ kv_cache::convert_thd_to_bshd(*convertNVTETensorCheck(tensor),
+ *convertNVTETensorCheck(cu_seqlens),
+ *convertNVTETensorCheck(new_tensor), b, max_seq_len, stream);
}
/***************************************************************************************************
@@ -291,7 +291,7 @@ void nvte_convert_bshd_to_thd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETens
NVTE_API_CALL(nvte_convert_bshd_to_thd);
using namespace transformer_engine;
- kv_cache::convert_bshd_to_thd(*reinterpret_cast<Tensor *>(tensor),
- *reinterpret_cast<Tensor *>(cu_seqlens),
- *reinterpret_cast<Tensor *>(new_tensor), t, stream);
+ kv_cache::convert_bshd_to_thd(*convertNVTETensorCheck(tensor),
+ *convertNVTETensorCheck(cu_seqlens),
+ *convertNVTETensorCheck(new_tensor), t, stream);
}
@@ -308,11 +308,10 @@ void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens
const int stride_d, cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_rope_forward);
using namespace transformer_engine;
- fused_rope_forward(
- *reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(cu_seqlens),
- *reinterpret_cast<const Tensor *>(freqs), *reinterpret_cast<const Tensor *>(start_positions),
- reinterpret_cast<Tensor *>(output), qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2,
- stride_s_or_t, stride_b, stride_h, stride_d, stream);
+ fused_rope_forward(*convertNVTETensorCheck(input), *convertNVTETensorCheck(cu_seqlens),
+ *convertNVTETensorCheck(freqs), *convertNVTETensorCheck(start_positions),
+ convertNVTETensorCheck(output), qkv_format, interleaved, cp_size, cp_rank, s,
+ b, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, stream);
}
void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
@@ -324,9 +323,8 @@ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu
cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_rope_backward);
using namespace transformer_engine;
- fused_rope_backward(*reinterpret_cast<const Tensor *>(output_grads),
- *reinterpret_cast<const Tensor *>(cu_seqlens),
- *reinterpret_cast<const Tensor *>(freqs),
- reinterpret_cast<Tensor *>(input_grads), qkv_format, interleaved, cp_size,
- cp_rank, s, b, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, stream);
+ fused_rope_backward(*convertNVTETensorCheck(output_grads), *convertNVTETensorCheck(cu_seqlens),
+ *convertNVTETensorCheck(freqs), convertNVTETensorCheck(input_grads),
+ qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t,
+ stride_b, stride_h, stride_d, stream);
}
@@ -551,8 +551,8 @@ void nvte_scaled_aligned_causal_masked_softmax_forward(const NVTETensor input,
float scale_factor, cudaStream_t stream) {
NVTE_API_CALL(nvte_scaled_aligned_causal_masked_softmax_forward);
using namespace transformer_engine;
- scaled_aligned_causal_masked_softmax_forward(*reinterpret_cast<const Tensor *>(input),
- reinterpret_cast<Tensor *>(softmax_results),
+ scaled_aligned_causal_masked_softmax_forward(*convertNVTETensorCheck(input),
+ convertNVTETensorCheck(softmax_results),
scale_factor, stream);
}
@@ -563,6 +563,6 @@ void nvte_scaled_aligned_causal_masked_softmax_backward(const NVTETensor incomin
NVTE_API_CALL(nvte_scaled_aligned_causal_masked_softmax_backward);
using namespace transformer_engine;
scaled_aligned_causal_masked_softmax_backward(
- *reinterpret_cast<Tensor *>(output_grads), *reinterpret_cast<const Tensor *>(incoming_grads),
- *reinterpret_cast<const Tensor *>(softmax_results), scale_factor, stream);
+ *convertNVTETensorCheck(output_grads), *convertNVTETensorCheck(incoming_grads),
+ *convertNVTETensorCheck(softmax_results), scale_factor, stream);
}
@@ -815,8 +815,8 @@ void nvte_scaled_softmax_forward(const NVTETensor input, NVTETensor softmax_resu
float scale_factor, cudaStream_t stream) {
NVTE_API_CALL(nvte_scaled_softmax_forward);
using namespace transformer_engine;
- scaled_softmax_forward(*reinterpret_cast<const Tensor *>(input),
- reinterpret_cast<Tensor *>(softmax_results), scale_factor, stream);
+ scaled_softmax_forward(*convertNVTETensorCheck(input), convertNVTETensorCheck(softmax_results),
+ scale_factor, stream);
}
void nvte_scaled_softmax_backward(const NVTETensor incoming_grads, const NVTETensor softmax_results,
@@ -824,9 +824,9 @@ void nvte_scaled_softmax_backward(const NVTETensor incoming_grads, const NVTETen
cudaStream_t stream) {
NVTE_API_CALL(nvte_scaled_softmax_backward);
using namespace transformer_engine;
- scaled_softmax_backward(*reinterpret_cast<Tensor *>(output_grads),
- *reinterpret_cast<const Tensor *>(incoming_grads),
- *reinterpret_cast<const Tensor *>(softmax_results), scale_factor, stream);
+ scaled_softmax_backward(*convertNVTETensorCheck(output_grads),
+ *convertNVTETensorCheck(incoming_grads),
+ *convertNVTETensorCheck(softmax_results), scale_factor, stream);
}
void nvte_scaled_masked_softmax_forward(const NVTETensor input, const NVTETensor mask,
@@ -834,9 +834,8 @@ void nvte_scaled_masked_softmax_forward(const NVTETensor input, const NVTETensor
cudaStream_t stream) {
NVTE_API_CALL(nvte_scaled_masked_softmax_forward);
using namespace transformer_engine;
- scaled_masked_softmax_forward(*reinterpret_cast<const Tensor *>(input),
- *reinterpret_cast<const Tensor *>(mask),
- reinterpret_cast<Tensor *>(softmax_results), scale_factor, stream);
+ scaled_masked_softmax_forward(*convertNVTETensorCheck(input), *convertNVTETensorCheck(mask),
+ convertNVTETensorCheck(softmax_results), scale_factor, stream);
}
void nvte_scaled_masked_softmax_backward(const NVTETensor incoming_grads,
@@ -844,7 +843,7 @@ void nvte_scaled_masked_softmax_backward(const NVTETensor incoming_grads,
float scale_factor, cudaStream_t stream) {
NVTE_API_CALL(nvte_scaled_masked_softmax_backward);
using namespace transformer_engine;
- scaled_masked_softmax_backward(
- *reinterpret_cast<Tensor *>(output_grads), *reinterpret_cast<const Tensor *>(incoming_grads),
- *reinterpret_cast<const Tensor *>(softmax_results), scale_factor, stream);
+ scaled_masked_softmax_backward(*convertNVTETensorCheck(output_grads),
+ *convertNVTETensorCheck(incoming_grads),
+ *convertNVTETensorCheck(softmax_results), scale_factor, stream);
}
@@ -599,9 +599,9 @@ void nvte_scaled_upper_triang_masked_softmax_forward(const NVTETensor input,
NVTETensor softmax_results, float scale_factor,
cudaStream_t stream) {
using namespace transformer_engine;
- scaled_upper_triang_masked_softmax_forward(*reinterpret_cast<const Tensor *>(input),
- reinterpret_cast<Tensor *>(softmax_results),
- scale_factor, stream);
+ scaled_upper_triang_masked_softmax_forward(*convertNVTETensorCheck(input),
+ convertNVTETensorCheck(softmax_results), scale_factor,
+ stream);
}
void nvte_scaled_upper_triang_masked_softmax_backward(const NVTETensor incoming_grads,
@@ -610,6 +610,6 @@ void nvte_scaled_upper_triang_masked_softmax_backward(const NVTETensor incoming_
cudaStream_t stream) {
using namespace transformer_engine;
scaled_upper_triang_masked_softmax_backward(
- *reinterpret_cast<Tensor *>(output_grads), *reinterpret_cast<const Tensor *>(incoming_grads),
- *reinterpret_cast<const Tensor *>(softmax_results), scale_factor, stream);
+ *convertNVTETensorCheck(output_grads), *convertNVTETensorCheck(incoming_grads),
+ *convertNVTETensorCheck(softmax_results), scale_factor, stream);
}
@@ -588,12 +588,12 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
int math_sm_count, cudaStream_t stream) {
NVTE_API_CALL(nvte_cublas_gemm);
using namespace transformer_engine;
- const Tensor *inputA = reinterpret_cast<const Tensor *>(A);
- const Tensor *inputB = reinterpret_cast<const Tensor *>(B);
- Tensor *outputD = reinterpret_cast<Tensor *>(D);
- const Tensor *biasTensor = reinterpret_cast<const Tensor *>(bias);
- Tensor *outputGelu = reinterpret_cast<Tensor *>(pre_gelu_out);
- Tensor *wspace = reinterpret_cast<Tensor *>(workspace);
+ const Tensor *inputA = convertNVTETensorCheck(A);
+ const Tensor *inputB = convertNVTETensorCheck(B);
+ Tensor *outputD = convertNVTETensor(D);
+ const Tensor *biasTensor = convertNVTETensor(bias);
+ Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
+ Tensor *wspace = convertNVTETensor(workspace);
cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
(transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0],
@@ -616,13 +616,13 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
"Cublas version >=12.2.5 and <13.0 is required for atomic gemm.");
using namespace transformer_engine;
- const Tensor *inputA = reinterpret_cast<const Tensor *>(A);
- const Tensor *inputB = reinterpret_cast<const Tensor *>(B);
- Tensor *outputD = reinterpret_cast<Tensor *>(D);
- const Tensor *biasTensor = reinterpret_cast<const Tensor *>(bias);
- Tensor *outputGelu = reinterpret_cast<Tensor *>(pre_gelu_out);
- const Tensor *inputCounter = reinterpret_cast<const Tensor *>(counter);
- Tensor *wspace = reinterpret_cast<Tensor *>(workspace);
+ const Tensor *inputA = convertNVTETensorCheck(A);
+ const Tensor *inputB = convertNVTETensorCheck(B);
+ Tensor *outputD = convertNVTETensor(D);
+ const Tensor *biasTensor = convertNVTETensor(bias);
+ Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
+ const Tensor *inputCounter = convertNVTETensor(counter);
+ Tensor *wspace = convertNVTETensor(workspace);
NVTE_CHECK(is_delayed_tensor_scaling(inputA->scaling_mode) &&
is_delayed_tensor_scaling(inputB->scaling_mode),
...
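One detail visible in the two GEMM wrappers above: the required operands `A` and `B` go through `convertNVTETensorCheck`, while `D`, `bias`, `pre_gelu_out`, `counter`, and `workspace` use the unchecked `convertNVTETensor`. The toy program below only illustrates that checked-vs-unchecked calling convention; its types and the assumption that the unchecked form simply passes a null handle through are hypothetical.

```cpp
#include <cassert>
#include <cstdio>

using Handle = void *;       // stand-in for the opaque NVTETensor handle
struct Payload { int id; };  // stand-in for the internal Tensor

static Payload *convert(Handle h) { return static_cast<Payload *>(h); }

static Payload *convertCheck(Handle h) {
  assert(h != nullptr && "required operand must be a valid handle");
  return convert(h);
}

int main() {
  Payload a{1};
  Handle required = &a;       // e.g. the A or B matrix
  Handle optional = nullptr;  // e.g. an unused bias or pre_gelu_out
  Payload *pa = convertCheck(required);
  Payload *pb = convert(optional);  // assumed to tolerate absent operands
  std::printf("required id=%d, optional present=%d\n", pa->id, pb != nullptr);
  return 0;
}
```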
@@ -806,7 +806,7 @@ void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETenso
using namespace transformer_engine;
multi_tensor_adam::multi_tensor_adam_cuda(
- chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
+ chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2,
epsilon, step, mode, bias_correction, weight_decay, device_id, stream);
}
@@ -820,7 +820,7 @@ void nvte_multi_tensor_adam_param_remainder_cuda(
using namespace transformer_engine;
multi_tensor_adam::multi_tensor_adam_param_remainder_cuda(
- chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
+ chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2,
epsilon, step, mode, bias_correction, weight_decay, device_id, stream);
}
@@ -836,7 +836,7 @@ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
using namespace transformer_engine;
multi_tensor_adam::multi_tensor_adam_fp8_cuda(
- chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
+ chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2,
epsilon, step, mode, bias_correction, weight_decay, static_cast<DType>(fp8_dtype), device_id,
stream);
@@ -851,11 +851,10 @@ void nvte_multi_tensor_adam_capturable_cuda(
using namespace transformer_engine;
multi_tensor_adam::multi_tensor_adam_capturable_cuda(
- chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
+ chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
- *reinterpret_cast<Tensor *>(lr), beta1, beta2, epsilon, *reinterpret_cast<Tensor *>(step),
- mode, bias_correction, weight_decay, *reinterpret_cast<Tensor *>(inv_scale), device_id,
- stream);
+ *convertNVTETensorCheck(lr), beta1, beta2, epsilon, *convertNVTETensorCheck(step), mode,
+ bias_correction, weight_decay, *convertNVTETensorCheck(inv_scale), device_id, stream);
}
void nvte_multi_tensor_adam_capturable_master_cuda(
@@ -867,9 +866,8 @@ void nvte_multi_tensor_adam_capturable_master_cuda(
using namespace transformer_engine;
multi_tensor_adam::multi_tensor_adam_capturable_master_cuda(
- chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
+ chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
- *reinterpret_cast<Tensor *>(lr), beta1, beta2, epsilon, *reinterpret_cast<Tensor *>(step),
- mode, bias_correction, weight_decay, *reinterpret_cast<Tensor *>(inv_scale), device_id,
- stream);
+ *convertNVTETensorCheck(lr), beta1, beta2, epsilon, *convertNVTETensorCheck(step), mode,
+ bias_correction, weight_decay, *convertNVTETensorCheck(inv_scale), device_id, stream);
}
@@ -77,7 +77,7 @@ void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(
using namespace transformer_engine;
multi_tensor_compute_scale::multi_tensor_compute_scale_and_scale_inv_cuda(
- chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
+ chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), max_fp8,
force_pow_2_scales, epsilon, device_id, stream);
}
@@ -459,10 +459,10 @@ void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETen
using namespace transformer_engine;
multi_tensor_l2norm::multi_tensor_l2norm_cuda(
- chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
+ chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
- *reinterpret_cast<Tensor *>(output), *reinterpret_cast<Tensor *>(output_per_tensor),
- *reinterpret_cast<Tensor *>(ret), *reinterpret_cast<Tensor *>(ret_per_tensor), per_tensor,
+ *convertNVTETensorCheck(output), *convertNVTETensorCheck(output_per_tensor),
+ *convertNVTETensorCheck(ret), *convertNVTETensorCheck(ret_per_tensor), per_tensor,
max_chunks_per_tensor, device_id, stream);
}
@@ -477,9 +477,9 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
using namespace transformer_engine;
multi_tensor_l2norm::multi_tensor_unscale_l2norm_cuda(
- chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
+ chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
- *reinterpret_cast<Tensor *>(output), *reinterpret_cast<Tensor *>(output_per_tensor),
- *reinterpret_cast<Tensor *>(ret), *reinterpret_cast<Tensor *>(ret_per_tensor),
- *reinterpret_cast<Tensor *>(inv_scale), per_tensor, max_chunks_per_tensor, device_id, stream);
+ *convertNVTETensorCheck(output), *convertNVTETensorCheck(output_per_tensor),
+ *convertNVTETensorCheck(ret), *convertNVTETensorCheck(ret_per_tensor),
+ *convertNVTETensorCheck(inv_scale), per_tensor, max_chunks_per_tensor, device_id, stream);
}
@@ -124,7 +124,7 @@ void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETens
using namespace transformer_engine;
multi_tensor_scale::multi_tensor_scale_cuda(
- chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
+ chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), scale, device_id,
stream);
}
@@ -196,7 +196,7 @@ void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor
using namespace transformer_engine;
multi_tensor_sgd::multi_tensor_sgd_cuda(
- chunk_size, *reinterpret_cast<Tensor*>(noop_flag),
+ chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), wd, momentum,
dampening, lr, nesterov, first_run, wd_after_momentum, scale, device_id, stream);
}
@@ -105,11 +105,11 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
// Compute FP8 transpose if required
if (z->has_columnwise_data() && is_tensor_scaling(z->scaling_mode)) {
- Tensor transpose_data;
- transpose_data.data = z->columnwise_data;
- transpose_data.scaling_mode = z->scaling_mode;
- nvte_transpose(reinterpret_cast<NVTETensor>(z), reinterpret_cast<NVTETensor>(&transpose_data),
- stream);
+ NVTETensor transpose_data = nvte_create_tensor(z->scaling_mode);
+ Tensor& t = *convertNVTETensor(transpose_data);
+ t.data = z->columnwise_data;
+ nvte_transpose(static_cast<NVTETensor>(*z), transpose_data, stream);
+ nvte_destroy_tensor(transpose_data);
}
return;
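The rewritten transpose path above now pairs `nvte_create_tensor(z->scaling_mode)` with an explicit `nvte_destroy_tensor(transpose_data)`. A small RAII guard such as the hypothetical one below would keep that cleanup exception- and early-return-safe; it assumes only the two C API calls visible in this hunk and the usual public header location.

```cpp
#include <transformer_engine/transformer_engine.h>  // assumed public header

// Hypothetical helper, not part of this PR: owns a temporary NVTETensor so the
// matching nvte_destroy_tensor call can never be skipped.
class ScopedNVTETensor {
 public:
  explicit ScopedNVTETensor(NVTEScalingMode mode) : handle_(nvte_create_tensor(mode)) {}
  ~ScopedNVTETensor() { nvte_destroy_tensor(handle_); }
  ScopedNVTETensor(const ScopedNVTETensor &) = delete;
  ScopedNVTETensor &operator=(const ScopedNVTETensor &) = delete;
  operator NVTETensor() const { return handle_; }

 private:
  NVTETensor handle_;
};
```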
@@ -195,11 +195,10 @@ void nvte_layernorm_fwd(const NVTETensor x, // BxSxhidden_size
const bool zero_centered_gamma, cudaStream_t stream) {
NVTE_API_CALL(nvte_layernorm_fwd);
using namespace transformer_engine;
- layernorm_fwd(*reinterpret_cast<const Tensor*>(x), *reinterpret_cast<const Tensor*>(gamma),
- *reinterpret_cast<const Tensor*>(beta), epsilon, reinterpret_cast<Tensor*>(z),
- reinterpret_cast<Tensor*>(mu), reinterpret_cast<Tensor*>(rsigma),
- reinterpret_cast<Tensor*>(workspace), multiprocessorCount, zero_centered_gamma,
- stream);
+ layernorm_fwd(*convertNVTETensorCheck(x), *convertNVTETensorCheck(gamma),
+ *convertNVTETensorCheck(beta), epsilon, convertNVTETensor(z), convertNVTETensor(mu),
+ convertNVTETensor(rsigma), convertNVTETensor(workspace), multiprocessorCount,
+ zero_centered_gamma, stream);
}
void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size
@@ -212,10 +211,9 @@ void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size
cudaStream_t stream) {
NVTE_API_CALL(nvte_layernorm_bwd);
using namespace transformer_engine;
- layernorm_bwd(*reinterpret_cast<const Tensor*>(dz), *reinterpret_cast<const Tensor*>(x),
- *reinterpret_cast<const Tensor*>(mu), *reinterpret_cast<const Tensor*>(rsigma),
- *reinterpret_cast<const Tensor*>(gamma), reinterpret_cast<Tensor*>(dx),
- reinterpret_cast<Tensor*>(dgamma), reinterpret_cast<Tensor*>(dbeta),
- reinterpret_cast<Tensor*>(workspace), multiprocessorCount, zero_centered_gamma,
- stream);
+ layernorm_bwd(*convertNVTETensorCheck(dz), *convertNVTETensorCheck(x),
+ *convertNVTETensorCheck(mu), *convertNVTETensorCheck(rsigma),
+ *convertNVTETensorCheck(gamma), convertNVTETensor(dx), convertNVTETensor(dgamma),
+ convertNVTETensor(dbeta), convertNVTETensor(workspace), multiprocessorCount,
+ zero_centered_gamma, stream);
}