"vscode:/vscode.git/clone" did not exist on "380170038e05cf81953c29d7e8ed789e048b6434"
Unverified commit 4292653c, authored by Przemyslaw Tredak and committed by GitHub

Avoid memory allocations and deallocations when creating NVTETensor (#1813)

* Changed the Tensor allocation strategy
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fixes
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Disable debug flag
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix the double free error
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fixed PyTorch recipe extension
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Hide TensorAllocator and fix the usage in LayerNorm
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Cleaning
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix permutation
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

---------
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 41909dc8
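The heart of this change is how NVTETensor handles map to the internal Tensor struct. Previously every call site did a raw reinterpret_cast between the opaque handle and Tensor*; now everything funnels through two helpers: convertNVTETensor (null-tolerant, for optional arguments such as workspaces) and convertNVTETensorCheck (errors out on an invalid handle). The diff only shows their declarations; a minimal sketch of what they plausibly do, assuming the handle still wraps a pooled Tensor object (the real implementations are internal and may differ):

// Sketch only; assumes transformer_engine/common/common.h for Tensor and NVTE_CHECK.
Tensor *convertNVTETensor(const NVTETensor tensor) {
  // Null handles are allowed and simply map to nullptr.
  return reinterpret_cast<Tensor *>(tensor);
}

Tensor *convertNVTETensorCheck(const NVTETensor tensor) {
  Tensor *ret = convertNVTETensor(tensor);
  NVTE_CHECK(ret != nullptr, "Invalid NVTETensor handle.");
  return ret;
}

Centralizing the conversion is what lets the library change the allocation strategy (for example, handing out recycled Tensor objects from a pool) without touching the hundreds of call sites again.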
@@ -200,7 +200,7 @@ std::vector<std::vector<Tensor *>> convert_tensor_array(NVTETensor **nvte_tensor
for (size_t i = 0; i < outer_size; ++i) {
ret.emplace_back();
for (size_t j = 0; j < inner_size; ++j) {
- ret.back().push_back(reinterpret_cast<Tensor *>(nvte_tensors[i][j]));
+ ret.back().push_back(convertNVTETensor(nvte_tensors[i][j]));
}
}
return ret;
......
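convert_tensor_array is the bulk variant used by the multi-tensor optimizer wrappers further down: it turns a C array-of-arrays of handles into vectors of Tensor pointers in one pass, now via the null-tolerant converter. A hypothetical usage, with the names params_list/grads_list and the sizes chosen only for illustration:

// Hypothetical illustration of convert_tensor_array's contract
// (assumes the transformer_engine internal headers and <vector>).
NVTETensor *tensor_lists[2] = {params_list, grads_list};  // each points to 3 handles
std::vector<std::vector<Tensor *>> lists =
    convert_tensor_array(tensor_lists, /*outer_size=*/2, /*inner_size=*/3);
// lists[0][j] and lists[1][j] are now directly usable by the CUDA launchers.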
@@ -89,9 +89,16 @@ struct SimpleTensor {
}
return acc;
}
+ void clear() {
+ dptr = nullptr;
+ shape.resize(0);
+ dtype = DType::kFloat32;
+ }
};
struct Tensor {
public:
SimpleTensor data;
SimpleTensor columnwise_data;
SimpleTensor amax;
@@ -99,8 +106,8 @@ struct Tensor {
SimpleTensor scale_inv;
SimpleTensor columnwise_scale_inv;
public:
NVTEScalingMode scaling_mode;
+ NVTETensor nvte_tensor;
Tensor()
: data(),
@@ -109,7 +116,20 @@ struct Tensor {
scale(nullptr, {1}, DType::kFloat32),
scale_inv(nullptr, {1}, DType::kFloat32),
columnwise_scale_inv(nullptr, {1}, DType::kFloat32),
- scaling_mode(NVTE_DELAYED_TENSOR_SCALING) {}
+ scaling_mode(NVTE_DELAYED_TENSOR_SCALING),
+ nvte_tensor(0) {}
+ void clear() {
+ data.clear();
+ columnwise_data.clear();
+ amax.clear();
+ scale.clear();
+ scale_inv.clear();
+ columnwise_scale_inv.clear();
+ scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
+ }
+ explicit operator NVTETensor() const noexcept { return nvte_tensor; }
size_t numel() const {
size_t acc = 1;
@@ -620,6 +640,8 @@ bool is_supported_by_CC_100();
std::vector<std::vector<Tensor *>> convert_tensor_array(NVTETensor **nvte_tensors,
size_t outer_size, size_t inner_size);
+ Tensor *convertNVTETensor(const NVTETensor tensor);
+ Tensor *convertNVTETensorCheck(const NVTETensor tensor);
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_COMMON_H_
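Tensor::clear() resets a Tensor to its default-constructed state without touching the stored nvte_tensor back-reference, and the explicit operator NVTETensor returns that handle. Both only make sense if Tensor objects outlive individual create/destroy calls, which is consistent with the TensorAllocator named in the commit list: a pool that recycles Tensor objects instead of hitting the heap for every nvte_create_tensor/nvte_destroy_tensor pair. The allocator itself is hidden from this header; a minimal sketch of the idea, with every name other than Tensor and clear() hypothetical:

#include <memory>
#include <mutex>
#include <vector>

// Hypothetical free-list allocator; the real (hidden) TensorAllocator may differ.
class TensorAllocator {
 public:
  Tensor *Allocate() {
    std::lock_guard<std::mutex> lock(mutex_);
    if (free_list_.empty()) {
      pool_.push_back(std::make_unique<Tensor>());  // grow the pool once
      return pool_.back().get();
    }
    Tensor *t = free_list_.back();  // reuse a previously released object
    free_list_.pop_back();
    return t;
  }
  void Free(Tensor *t) {
    std::lock_guard<std::mutex> lock(mutex_);
    t->clear();               // reset metadata, keep the allocation alive
    free_list_.push_back(t);  // no delete: no heap traffic on the hot path
  }

 private:
  std::mutex mutex_;
  std::vector<std::unique_ptr<Tensor>> pool_;
  std::vector<Tensor *> free_list_;
};

With a scheme like this, nvte_destroy_tensor becomes a cheap list push instead of a delete, which is exactly the per-call allocation/deallocation traffic the PR title is about.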
@@ -677,9 +677,9 @@ void nvte_cp_thd_read_half_tensor(const NVTETensor &tensor, const NVTETensor &cu
NVTE_API_CALL(nvte_thd_read_half_tensor);
using namespace transformer_engine;
- context_parallel::thd_read_half_tensor(*reinterpret_cast<Tensor *>(tensor),
- *reinterpret_cast<Tensor *>(cu_seqlens),
- *reinterpret_cast<Tensor *>(half), half_idx, stream);
+ context_parallel::thd_read_half_tensor(*convertNVTETensorCheck(tensor),
+ *convertNVTETensorCheck(cu_seqlens),
+ *convertNVTETensorCheck(half), half_idx, stream);
}
void nvte_cp_thd_second_half_lse_correction(NVTETensor lse, const NVTETensor &lse_per_step,
@@ -689,8 +689,8 @@ void nvte_cp_thd_second_half_lse_correction(NVTETensor lse, const NVTETensor &ls
using namespace transformer_engine;
context_parallel::thd_second_half_lse_correction(
- *reinterpret_cast<Tensor *>(lse), *reinterpret_cast<Tensor *>(lse_per_step),
- *reinterpret_cast<Tensor *>(cu_seqlens), lse_packed, stream);
+ *convertNVTETensorCheck(lse), *convertNVTETensorCheck(lse_per_step),
+ *convertNVTETensorCheck(cu_seqlens), lse_packed, stream);
}
void nvte_cp_thd_read_second_half_lse(const NVTETensor &lse, const NVTETensor &cu_seqlens,
@@ -700,8 +700,8 @@ void nvte_cp_thd_read_second_half_lse(const NVTETensor &lse, const NVTETensor &c
using namespace transformer_engine;
context_parallel::thd_read_second_half_lse(
- *reinterpret_cast<Tensor *>(lse), *reinterpret_cast<Tensor *>(cu_seqlens),
- *reinterpret_cast<Tensor *>(half_lse), lse_packed, second_half_lse_seqlen, stream);
+ *convertNVTETensorCheck(lse), *convertNVTETensorCheck(cu_seqlens),
+ *convertNVTETensorCheck(half_lse), lse_packed, second_half_lse_seqlen, stream);
}
void nvte_cp_thd_out_correction(NVTETensor out, const NVTETensor &out_per_step,
@@ -712,9 +712,9 @@ void nvte_cp_thd_out_correction(NVTETensor out, const NVTETensor &out_per_step,
using namespace transformer_engine;
context_parallel::thd_out_correction(
- *reinterpret_cast<Tensor *>(out), *reinterpret_cast<Tensor *>(out_per_step),
- *reinterpret_cast<Tensor *>(lse), *reinterpret_cast<Tensor *>(lse_per_step),
- *reinterpret_cast<Tensor *>(cu_seqlens), only_second_half, lse_packed, stream);
+ *convertNVTETensorCheck(out), *convertNVTETensorCheck(out_per_step),
+ *convertNVTETensorCheck(lse), *convertNVTETensorCheck(lse_per_step),
+ *convertNVTETensorCheck(cu_seqlens), only_second_half, lse_packed, stream);
}
void nvte_cp_thd_grad_correction(NVTETensor grad, const NVTETensor &grad_per_step,
@@ -727,8 +727,8 @@ void nvte_cp_thd_grad_correction(NVTETensor grad, const NVTETensor &grad_per_ste
std::string second_half_str(second_half);
context_parallel::thd_grad_correction(
- *reinterpret_cast<Tensor *>(grad), *reinterpret_cast<Tensor *>(grad_per_step),
- *reinterpret_cast<Tensor *>(cu_seqlens), first_half_str, second_half_str, stream);
+ *convertNVTETensorCheck(grad), *convertNVTETensorCheck(grad_per_step),
+ *convertNVTETensorCheck(cu_seqlens), first_half_str, second_half_str, stream);
}
void nvte_cp_thd_get_partitioned_indices(const NVTETensor &cu_seqlens, NVTETensor output,
@@ -737,7 +737,7 @@ void nvte_cp_thd_get_partitioned_indices(const NVTETensor &cu_seqlens, NVTETenso
NVTE_API_CALL(nvte_thd_get_partitioned_indices);
using namespace transformer_engine;
- context_parallel::thd_get_partitioned_indices(*reinterpret_cast<Tensor *>(cu_seqlens),
- *reinterpret_cast<Tensor *>(output), total_tokens,
+ context_parallel::thd_get_partitioned_indices(*convertNVTETensorCheck(cu_seqlens),
+ *convertNVTETensorCheck(output), total_tokens,
world_size, rank, stream);
}
@@ -138,8 +138,8 @@ void nvte_prepare_flash_attn_fwd(NVTETensor qkvi, NVTETensor qkv, cudaStream_t s
NVTE_API_CALL(nvte_prepare_flash_attn_fwd);
using namespace transformer_engine;
- flash_attention::prepare_flash_attn_fwd(*reinterpret_cast<Tensor *>(qkvi),
- *reinterpret_cast<Tensor *>(qkv), stream);
+ flash_attention::prepare_flash_attn_fwd(*convertNVTETensorCheck(qkvi),
+ *convertNVTETensorCheck(qkv), stream);
}
void nvte_prepare_flash_attn_bwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTETensor qkv,
@@ -147,7 +147,7 @@ void nvte_prepare_flash_attn_bwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTET
NVTE_API_CALL(nvte_prepare_flash_attn_bwd);
using namespace transformer_engine;
- flash_attention::prepare_flash_attn_bwd(
- *reinterpret_cast<Tensor *>(q), *reinterpret_cast<Tensor *>(k),
- *reinterpret_cast<Tensor *>(v), *reinterpret_cast<Tensor *>(qkv), stream);
+ flash_attention::prepare_flash_attn_bwd(*convertNVTETensorCheck(q), *convertNVTETensorCheck(k),
+ *convertNVTETensorCheck(v), *convertNVTETensorCheck(qkv),
+ stream);
}
@@ -392,14 +392,14 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked);
using namespace transformer_engine;
- const Tensor *input_cu_seqlens = reinterpret_cast<const Tensor *>(cu_seqlens);
- const Tensor *input_cu_seqlens_padded = reinterpret_cast<const Tensor *>(cu_seqlens_padded);
- const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(rng_state);
- const Tensor *input_QKV = reinterpret_cast<const Tensor *>(QKV);
- const Tensor *input_Bias = reinterpret_cast<const Tensor *>(Bias);
- Tensor *input_output_S = reinterpret_cast<Tensor *>(S);
- Tensor *output_O = reinterpret_cast<Tensor *>(O);
- Tensor *wkspace = reinterpret_cast<Tensor *>(workspace);
+ const Tensor *input_cu_seqlens = convertNVTETensorCheck(cu_seqlens);
+ const Tensor *input_cu_seqlens_padded = convertNVTETensorCheck(cu_seqlens_padded);
+ const Tensor *input_rng_state = convertNVTETensorCheck(rng_state);
+ const Tensor *input_QKV = convertNVTETensorCheck(QKV);
+ const Tensor *input_Bias = convertNVTETensorCheck(Bias);
+ Tensor *input_output_S = convertNVTETensorCheck(S);
+ Tensor *output_O = convertNVTETensorCheck(O);
+ Tensor *wkspace = convertNVTETensor(workspace);
auto ndim = input_QKV->data.shape.size();
size_t b = input_cu_seqlens->data.shape[0] - 1;
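Note the asymmetry in the block above: every required input goes through convertNVTETensorCheck, while workspace uses the unchecked convertNVTETensor. Fused-attention entry points follow the usual two-call convention, where the first call only reports the required workspace size, so the workspace tensor may still be empty at that point. Roughly (an illustrative sketch, not the verbatim library code; required_workspace_bytes is a hypothetical helper):

size_t required_workspace_bytes();  // hypothetical

void example_op(NVTETensor workspace) {
  Tensor *wkspace = convertNVTETensor(workspace);  // may be null, no hard error
  if (wkspace == nullptr) return;                  // nothing to report into
  if (wkspace->data.dptr == nullptr) {
    wkspace->data.shape = {required_workspace_bytes()};  // sizing pass only
    wkspace->data.dtype = DType::kByte;
    return;
  }
  // Execution pass: the workspace is allocated, launch the kernels here.
}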
@@ -472,16 +472,16 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked);
using namespace transformer_engine;
- const Tensor *input_cu_seqlens = reinterpret_cast<const Tensor *>(cu_seqlens);
- const Tensor *input_cu_seqlens_padded = reinterpret_cast<const Tensor *>(cu_seqlens_padded);
- const Tensor *input_QKV = reinterpret_cast<const Tensor *>(QKV);
- const Tensor *input_O = reinterpret_cast<const Tensor *>(O);
- const Tensor *input_dO = reinterpret_cast<const Tensor *>(dO);
- const Tensor *input_S = reinterpret_cast<const Tensor *>(S);
- Tensor *input_output_dP = reinterpret_cast<Tensor *>(dP);
- Tensor *output_dQKV = reinterpret_cast<Tensor *>(dQKV);
- Tensor *output_dBias = reinterpret_cast<Tensor *>(dBias);
- Tensor *wkspace = reinterpret_cast<Tensor *>(workspace);
+ const Tensor *input_cu_seqlens = convertNVTETensorCheck(cu_seqlens);
+ const Tensor *input_cu_seqlens_padded = convertNVTETensorCheck(cu_seqlens_padded);
+ const Tensor *input_QKV = convertNVTETensorCheck(QKV);
+ const Tensor *input_O = convertNVTETensorCheck(O);
+ const Tensor *input_dO = convertNVTETensorCheck(dO);
+ const Tensor *input_S = convertNVTETensorCheck(S);
+ Tensor *input_output_dP = convertNVTETensorCheck(dP);
+ Tensor *output_dQKV = convertNVTETensorCheck(dQKV);
+ Tensor *output_dBias = convertNVTETensorCheck(dBias);
+ Tensor *wkspace = convertNVTETensor(workspace);
auto ndim = input_QKV->data.shape.size();
size_t b = input_cu_seqlens->data.shape[0] - 1;
@@ -510,7 +510,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
- Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+ Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
fused_attn_max_512_bwd_qkvpacked(
b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_QKV,
input_dO, output_S, output_dQKV, output_dBias, input_cu_seqlens, wkspace, stream, handle);
@@ -519,13 +519,13 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900)
- Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+ Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
Tensor *input_Bias, *input_rng_state;
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
- input_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
- input_Bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
+ input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+ input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
} else {
- input_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+ input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
}
fused_attn_arbitrary_seqlen_bwd_qkvpacked(
b, h, max_seqlen, d, t, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
@@ -540,9 +540,9 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
#if (CUDNN_VERSION >= 8900)
- const Tensor *input_M = reinterpret_cast<const Tensor *>(Aux_CTX_Tensors->tensors[0]);
- const Tensor *input_ZInv = reinterpret_cast<const Tensor *>(Aux_CTX_Tensors->tensors[1]);
- const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(Aux_CTX_Tensors->tensors[2]);
+ const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
+ const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+ const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
fused_attn_fp8_bwd_qkvpacked(b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type,
attn_mask_type, input_QKV, input_O, input_dO, input_M, input_ZInv,
input_S, input_output_dP, output_dQKV, input_cu_seqlens,
@@ -566,19 +566,19 @@ void nvte_fused_attn_fwd_kvpacked(
cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked);
using namespace transformer_engine;
- const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
- const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv);
- const Tensor *input_cu_seqlens_q_padded = reinterpret_cast<const Tensor *>(cu_seqlens_q_padded);
- const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast<const Tensor *>(cu_seqlens_kv_padded);
- const Tensor *input_page_table_k = reinterpret_cast<const Tensor *>(page_table_k);
- const Tensor *input_page_table_v = reinterpret_cast<const Tensor *>(page_table_v);
- const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(rng_state);
- const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q);
- const Tensor *input_KV = reinterpret_cast<const Tensor *>(KV);
- const Tensor *input_Bias = reinterpret_cast<const Tensor *>(Bias);
- Tensor *input_output_S = reinterpret_cast<Tensor *>(S);
- Tensor *output_O = reinterpret_cast<Tensor *>(O);
- Tensor *wkspace = reinterpret_cast<Tensor *>(workspace);
+ const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
+ const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv);
+ const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded);
+ const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded);
+ const Tensor *input_page_table_k = convertNVTETensorCheck(page_table_k);
+ const Tensor *input_page_table_v = convertNVTETensorCheck(page_table_v);
+ const Tensor *input_rng_state = convertNVTETensorCheck(rng_state);
+ const Tensor *input_Q = convertNVTETensorCheck(Q);
+ const Tensor *input_KV = convertNVTETensorCheck(KV);
+ const Tensor *input_Bias = convertNVTETensorCheck(Bias);
+ Tensor *input_output_S = convertNVTETensorCheck(S);
+ Tensor *output_O = convertNVTETensorCheck(O);
+ Tensor *wkspace = convertNVTETensor(workspace);
size_t b = input_cu_seqlens_q->data.shape[0] - 1;
auto ndim = input_Q->data.shape.size();
@@ -686,20 +686,20 @@ void nvte_fused_attn_bwd_kvpacked(
cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked);
using namespace transformer_engine;
- const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
- const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv);
- const Tensor *input_cu_seqlens_q_padded = reinterpret_cast<const Tensor *>(cu_seqlens_q_padded);
- const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast<const Tensor *>(cu_seqlens_kv_padded);
- const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q);
- const Tensor *input_KV = reinterpret_cast<const Tensor *>(KV);
- const Tensor *input_O = reinterpret_cast<const Tensor *>(O);
- const Tensor *input_dO = reinterpret_cast<const Tensor *>(dO);
- const Tensor *input_S = reinterpret_cast<const Tensor *>(S);
- Tensor *input_output_dP = reinterpret_cast<Tensor *>(dP);
- Tensor *output_dQ = reinterpret_cast<Tensor *>(dQ);
- Tensor *output_dKV = reinterpret_cast<Tensor *>(dKV);
- Tensor *output_dBias = reinterpret_cast<Tensor *>(dBias);
- Tensor *wkspace = reinterpret_cast<Tensor *>(workspace);
+ const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
+ const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv);
+ const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded);
+ const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded);
+ const Tensor *input_Q = convertNVTETensorCheck(Q);
+ const Tensor *input_KV = convertNVTETensorCheck(KV);
+ const Tensor *input_O = convertNVTETensorCheck(O);
+ const Tensor *input_dO = convertNVTETensorCheck(dO);
+ const Tensor *input_S = convertNVTETensorCheck(S);
+ Tensor *input_output_dP = convertNVTETensorCheck(dP);
+ Tensor *output_dQ = convertNVTETensorCheck(dQ);
+ Tensor *output_dKV = convertNVTETensorCheck(dKV);
+ Tensor *output_dBias = convertNVTETensorCheck(dBias);
+ Tensor *wkspace = convertNVTETensor(workspace);
size_t b = input_cu_seqlens_q->data.shape[0] - 1;
auto ndim = input_Q->data.shape.size();
@@ -736,7 +736,7 @@ void nvte_fused_attn_bwd_kvpacked(
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
- Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+ Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
fused_attn_max_512_bwd_kvpacked(
b, h_q, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type,
attn_mask_type, input_Q, input_KV, input_dO, output_S, output_dQ, output_dKV, output_dBias,
@@ -746,13 +746,13 @@ void nvte_fused_attn_bwd_kvpacked(
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8903)
- Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+ Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
Tensor *input_Bias, *input_rng_state;
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
- input_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
- input_Bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
+ input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+ input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
} else {
- input_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+ input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
}
fused_attn_arbitrary_seqlen_bwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, attn_scale, dropout, qkv_layout,
@@ -768,9 +768,9 @@ void nvte_fused_attn_bwd_kvpacked(
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
#if (CUDNN_VERSION >= 8900)
- const Tensor *input_M = reinterpret_cast<const Tensor *>(Aux_CTX_Tensors->tensors[0]);
- const Tensor *input_ZInv = reinterpret_cast<const Tensor *>(Aux_CTX_Tensors->tensors[1]);
- const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(Aux_CTX_Tensors->tensors[2]);
+ const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
+ const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+ const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
fused_attn_fp8_bwd_kvpacked(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout,
qkv_layout, bias_type, attn_mask_type, input_Q, input_KV, input_O,
input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ,
@@ -797,20 +797,20 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_fwd);
using namespace transformer_engine;
- const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
- const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv);
- const Tensor *input_cu_seqlens_q_padded = reinterpret_cast<const Tensor *>(cu_seqlens_q_padded);
- const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast<const Tensor *>(cu_seqlens_kv_padded);
- const Tensor *input_page_table_k = reinterpret_cast<const Tensor *>(page_table_k);
- const Tensor *input_page_table_v = reinterpret_cast<const Tensor *>(page_table_v);
- const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(rng_state);
- const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q);
- const Tensor *input_K = reinterpret_cast<const Tensor *>(K);
- const Tensor *input_V = reinterpret_cast<const Tensor *>(V);
- const Tensor *input_Bias = reinterpret_cast<const Tensor *>(Bias);
- Tensor *input_output_S = reinterpret_cast<Tensor *>(S);
- Tensor *output_O = reinterpret_cast<Tensor *>(O);
- Tensor *wkspace = reinterpret_cast<Tensor *>(workspace);
+ const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
+ const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv);
+ const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded);
+ const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded);
+ const Tensor *input_page_table_k = convertNVTETensorCheck(page_table_k);
+ const Tensor *input_page_table_v = convertNVTETensorCheck(page_table_v);
+ const Tensor *input_rng_state = convertNVTETensorCheck(rng_state);
+ const Tensor *input_Q = convertNVTETensorCheck(Q);
+ const Tensor *input_K = convertNVTETensorCheck(K);
+ const Tensor *input_V = convertNVTETensorCheck(V);
+ const Tensor *input_Bias = convertNVTETensorCheck(Bias);
+ Tensor *input_output_S = convertNVTETensorCheck(S);
+ Tensor *output_O = convertNVTETensorCheck(O);
+ Tensor *wkspace = convertNVTETensor(workspace);
auto ndim = input_Q->data.shape.size();
auto ndim_kv = input_K->data.shape.size();
@@ -914,22 +914,22 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_bwd);
using namespace transformer_engine;
- const Tensor *input_cu_seqlens_q = reinterpret_cast<const Tensor *>(cu_seqlens_q);
- const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv);
- const Tensor *input_cu_seqlens_q_padded = reinterpret_cast<const Tensor *>(cu_seqlens_q_padded);
- const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast<const Tensor *>(cu_seqlens_kv_padded);
- const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q);
- const Tensor *input_K = reinterpret_cast<const Tensor *>(K);
- const Tensor *input_V = reinterpret_cast<const Tensor *>(V);
- const Tensor *input_O = reinterpret_cast<const Tensor *>(O);
- const Tensor *input_dO = reinterpret_cast<const Tensor *>(dO);
- const Tensor *input_S = reinterpret_cast<const Tensor *>(S);
- Tensor *input_output_dP = reinterpret_cast<Tensor *>(dP);
- Tensor *output_dQ = reinterpret_cast<Tensor *>(dQ);
- Tensor *output_dK = reinterpret_cast<Tensor *>(dK);
- Tensor *output_dV = reinterpret_cast<Tensor *>(dV);
- Tensor *output_dBias = reinterpret_cast<Tensor *>(dBias);
- Tensor *wkspace = reinterpret_cast<Tensor *>(workspace);
+ const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
+ const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv);
+ const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded);
+ const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded);
+ const Tensor *input_Q = convertNVTETensorCheck(Q);
+ const Tensor *input_K = convertNVTETensorCheck(K);
+ const Tensor *input_V = convertNVTETensorCheck(V);
+ const Tensor *input_O = convertNVTETensorCheck(O);
+ const Tensor *input_dO = convertNVTETensorCheck(dO);
+ const Tensor *input_S = convertNVTETensorCheck(S);
+ Tensor *input_output_dP = convertNVTETensorCheck(dP);
+ Tensor *output_dQ = convertNVTETensorCheck(dQ);
+ Tensor *output_dK = convertNVTETensorCheck(dK);
+ Tensor *output_dV = convertNVTETensorCheck(dV);
+ Tensor *output_dBias = convertNVTETensorCheck(dBias);
+ Tensor *wkspace = convertNVTETensor(workspace);
auto ndim = input_Q->data.shape.size();
auto ndim_kv = input_K->data.shape.size();
@@ -959,7 +959,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
- Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+ Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
fused_attn_max_512_bwd(b, h_q, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout,
qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V,
input_dO, output_S, output_dQ, output_dK, output_dV, output_dBias,
@@ -969,13 +969,13 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900)
- Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+ Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
Tensor *input_Bias, *input_rng_state;
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
- input_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
- input_Bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
+ input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+ input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
} else {
- input_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+ input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
}
fused_attn_arbitrary_seqlen_bwd(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, dropout,
@@ -991,9 +991,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
#if (CUDNN_VERSION >= 8900)
- const Tensor *input_M = reinterpret_cast<const Tensor *>(Aux_CTX_Tensors->tensors[0]);
- const Tensor *input_ZInv = reinterpret_cast<const Tensor *>(Aux_CTX_Tensors->tensors[1]);
- const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(Aux_CTX_Tensors->tensors[2]);
+ const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
+ const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+ const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout,
qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, input_O,
input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ,
......
@@ -990,7 +990,7 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
const auto cudnn_runtime_version = cudnnGetVersion();
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
Aux_CTX_Tensors->size = 3;
- Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+ Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
output_S->data.shape = {max_tokens, num_attn_heads, 1};
@@ -998,17 +998,17 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1};
}
output_S->data.dtype = DType::kFloat32;
- Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+ Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
- Tensor *output_bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
+ Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
output_bias->data.dptr = nullptr;
output_bias->data.shape = {bias_b, bias_h, max_seqlen, max_seqlen};
output_bias->data.dtype = QKV_type;
} else {
Aux_CTX_Tensors->size = 2;
- Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+ Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
output_S->data.shape = {max_tokens, num_attn_heads, 1};
@@ -1016,22 +1016,22 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1};
}
output_S->data.dtype = DType::kFloat32;
- Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+ Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
}
} else if (Aux_CTX_Tensors->size == 2) {
- Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+ Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
- Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+ Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = rng_state->data.dptr;
} else if (Aux_CTX_Tensors->size == 3) {
- Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+ Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
- Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+ Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = rng_state->data.dptr;
- Tensor *output_bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
+ Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
output_bias->data.dptr = devPtrBias;
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
@@ -1216,7 +1216,7 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
const auto cudnn_runtime_version = cudnnGetVersion();
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
Aux_CTX_Tensors->size = 3;
- Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+ Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
@@ -1224,17 +1224,17 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
}
output_S->data.dtype = DType::kFloat32;
- Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+ Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
- Tensor *output_bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
+ Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
output_bias->data.dptr = nullptr;
output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv};
output_bias->data.dtype = QKV_type;
} else {
Aux_CTX_Tensors->size = 2;
- Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+ Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
@@ -1242,22 +1242,22 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
}
output_S->data.dtype = DType::kFloat32;
- Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+ Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
}
} else if (Aux_CTX_Tensors->size == 2) {
- Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+ Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
- Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+ Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = rng_state->data.dptr;
} else if (Aux_CTX_Tensors->size == 3) {
- Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+ Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
- Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+ Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = rng_state->data.dptr;
- Tensor *output_bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
+ Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
output_bias->data.dptr = devPtrBias;
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
@@ -1446,7 +1446,7 @@ void fused_attn_arbitrary_seqlen_fwd(
const auto cudnn_runtime_version = cudnnGetVersion();
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
Aux_CTX_Tensors->size = 3;
- Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+ Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
@@ -1454,17 +1454,17 @@ void fused_attn_arbitrary_seqlen_fwd(
output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
}
output_S->data.dtype = DType::kFloat32;
- Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+ Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
- Tensor *output_bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
+ Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
output_bias->data.dptr = nullptr;
output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv};
output_bias->data.dtype = QKV_type;
} else {
Aux_CTX_Tensors->size = 2;
- Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+ Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
@@ -1472,22 +1472,22 @@ void fused_attn_arbitrary_seqlen_fwd(
output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
}
output_S->data.dtype = DType::kFloat32;
- Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+ Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = nullptr;
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
}
} else if (Aux_CTX_Tensors->size == 2) {
- Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+ Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
- Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+ Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = rng_state->data.dptr;
} else if (Aux_CTX_Tensors->size == 3) {
- Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+ Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
- Tensor *output_rng_state = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[1]);
+ Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
output_rng_state->data.dptr = rng_state->data.dptr;
- Tensor *output_bias = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[2]);
+ Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
output_bias->data.dptr = devPtrBias;
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
......
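All of the arbitrary-seqlen forward functions above follow the same two-phase protocol for Aux_CTX_Tensors: on the first call (size == 0) they only fill in size, shapes, and dtypes, leaving every dptr null so the framework can allocate the buffers; on the second call (size already 2 or 3) they read the device pointers back out. A caller-side sketch, assuming the NVTETensorPack helpers from fused_attn.h, with run_forward() and allocate_on_device() as hypothetical stand-ins:

NVTETensorPack aux;
nvte_tensor_pack_create(&aux);   // gives the pack valid, empty tensor handles
run_forward(&aux);               // phase 1: aux.size and the shapes are filled in
for (size_t i = 0; i < aux.size; ++i) {
  Tensor *t = convertNVTETensorCheck(aux.tensors[i]);
  t->data.dptr = allocate_on_device(t->data.shape, t->data.dtype);  // hypothetical
}
run_forward(&aux);               // phase 2: kernels run, using the new buffers
nvte_tensor_pack_destroy(&aux);

The converter change matters here because these auxiliary tensors are set up on every forward call, which is precisely where per-call heap allocation used to show up.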
@@ -1239,12 +1239,12 @@ void fused_attn_max_512_fwd_qkvpacked(
if (Aux_CTX_Tensors->size == 0) {
Aux_CTX_Tensors->size = 1;
- Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+ Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
output_S->data.shape = {batch, num_head, max_seqlen, max_seqlen};
output_S->data.dtype = input_QKV->data.dtype;
} else if (Aux_CTX_Tensors->size == 1) {
- Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+ Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
@@ -1317,12 +1317,12 @@ void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t num_head, size_t q_max
if (Aux_CTX_Tensors->size == 0) {
Aux_CTX_Tensors->size = 1;
- Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+ Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
output_S->data.shape = {batch, num_head, q_max_seqlen, kv_max_seqlen};
output_S->data.dtype = q_type;
} else if (Aux_CTX_Tensors->size == 1) {
- Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+ Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
@@ -1386,12 +1386,12 @@ void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen,
if (Aux_CTX_Tensors->size == 0) {
Aux_CTX_Tensors->size = 1;
- Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+ Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
output_S->data.dptr = nullptr;
output_S->data.shape = {batch, num_head, q_max_seqlen, kv_max_seqlen};
output_S->data.dtype = q_type;
} else if (Aux_CTX_Tensors->size == 1) {
- Tensor *output_S = reinterpret_cast<Tensor *>(Aux_CTX_Tensors->tensors[0]);
+ Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
devPtrS = output_S->data.dptr;
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
......
@@ -2383,9 +2383,9 @@ void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t ma
void* devPtrZInv = nullptr;
if (Aux_CTX_Tensors->size == 0) {
Aux_CTX_Tensors->size = 3;
- Tensor* output_M = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[0]);
- Tensor* output_ZInv = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
- Tensor* output_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
+ Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
+ Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+ Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
output_M->data.dptr = nullptr;
output_M->data.shape = {batch, num_attn_heads, max_seqlen, 1};
output_M->data.dtype = DType::kFloat32;
@@ -2396,9 +2396,9 @@ void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t ma
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
} else if (Aux_CTX_Tensors->size == 3) {
- Tensor* output_M = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[0]);
- Tensor* output_ZInv = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
- Tensor* output_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
+ Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
+ Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+ Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
devPtrM = output_M->data.dptr;
devPtrZInv = output_ZInv->data.dptr;
output_rng_state->data.dptr = rng_state->data.dptr;
@@ -2582,9 +2582,9 @@ void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num
void* devPtrZInv = nullptr;
if (Aux_CTX_Tensors->size == 0) {
Aux_CTX_Tensors->size = 3;
- Tensor* output_M = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[0]);
- Tensor* output_ZInv = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
- Tensor* output_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
+ Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
+ Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+ Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
output_M->data.dptr = nullptr;
output_M->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
output_M->data.dtype = DType::kFloat32;
@@ -2595,9 +2595,9 @@ void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
} else if (Aux_CTX_Tensors->size == 3) {
- Tensor* output_M = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[0]);
- Tensor* output_ZInv = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
- Tensor* output_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
+ Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
+ Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+ Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
devPtrM = output_M->data.dptr;
devPtrZInv = output_ZInv->data.dptr;
output_rng_state->data.dptr = rng_state->data.dptr;
@@ -2779,9 +2779,9 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou
void* devPtrZInv = nullptr;
if (Aux_CTX_Tensors->size == 0) {
Aux_CTX_Tensors->size = 3;
- Tensor* output_M = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[0]);
- Tensor* output_ZInv = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
- Tensor* output_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
+ Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
+ Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+ Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
output_M->data.dptr = nullptr;
output_M->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
output_M->data.dtype = DType::kFloat32;
@@ -2792,9 +2792,9 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou
output_rng_state->data.shape = {2};
output_rng_state->data.dtype = DType::kInt64;
} else if (Aux_CTX_Tensors->size == 3) {
- Tensor* output_M = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[0]);
- Tensor* output_ZInv = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[1]);
- Tensor* output_rng_state = reinterpret_cast<Tensor*>(Aux_CTX_Tensors->tensors[2]);
+ Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
+ Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+ Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
devPtrM = output_M->data.dptr;
devPtrZInv = output_ZInv->data.dptr;
output_rng_state->data.dptr = rng_state->data.dptr;
......
@@ -260,12 +260,12 @@ void nvte_copy_to_kv_cache(NVTETensor new_k, NVTETensor new_v, NVTETensor k_cach
NVTE_API_CALL(nvte_copy_to_kv_cache);
using namespace transformer_engine;
- kv_cache::copy_to_kv_cache(
- *reinterpret_cast<Tensor *>(new_k), *reinterpret_cast<Tensor *>(new_v),
- *reinterpret_cast<Tensor *>(k_cache), *reinterpret_cast<Tensor *>(v_cache),
- *reinterpret_cast<Tensor *>(page_table), *reinterpret_cast<Tensor *>(cu_new_lens),
- *reinterpret_cast<Tensor *>(cu_cached_lens), qkv_format, b, max_ctx_len, max_seq_len,
- max_pages_per_seq, is_non_paged, stream);
+ kv_cache::copy_to_kv_cache(*convertNVTETensorCheck(new_k), *convertNVTETensorCheck(new_v),
+ *convertNVTETensorCheck(k_cache), *convertNVTETensorCheck(v_cache),
+ *convertNVTETensorCheck(page_table),
+ *convertNVTETensorCheck(cu_new_lens),
+ *convertNVTETensorCheck(cu_cached_lens), qkv_format, b, max_ctx_len,
+ max_seq_len, max_pages_per_seq, is_non_paged, stream);
}
/***************************************************************************************************
@@ -277,9 +277,9 @@ void nvte_convert_thd_to_bshd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETens
NVTE_API_CALL(nvte_convert_thd_to_bshd);
using namespace transformer_engine;
- kv_cache::convert_thd_to_bshd(*reinterpret_cast<Tensor *>(tensor),
- *reinterpret_cast<Tensor *>(cu_seqlens),
- *reinterpret_cast<Tensor *>(new_tensor), b, max_seq_len, stream);
+ kv_cache::convert_thd_to_bshd(*convertNVTETensorCheck(tensor),
+ *convertNVTETensorCheck(cu_seqlens),
+ *convertNVTETensorCheck(new_tensor), b, max_seq_len, stream);
}
/***************************************************************************************************
@@ -291,7 +291,7 @@ void nvte_convert_bshd_to_thd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETens
NVTE_API_CALL(nvte_convert_bshd_to_thd);
using namespace transformer_engine;
- kv_cache::convert_bshd_to_thd(*reinterpret_cast<Tensor *>(tensor),
- *reinterpret_cast<Tensor *>(cu_seqlens),
- *reinterpret_cast<Tensor *>(new_tensor), t, stream);
+ kv_cache::convert_bshd_to_thd(*convertNVTETensorCheck(tensor),
+ *convertNVTETensorCheck(cu_seqlens),
+ *convertNVTETensorCheck(new_tensor), t, stream);
}
@@ -308,11 +308,10 @@ void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens
const int stride_d, cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_rope_forward);
using namespace transformer_engine;
- fused_rope_forward(
- *reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(cu_seqlens),
- *reinterpret_cast<const Tensor *>(freqs), *reinterpret_cast<const Tensor *>(start_positions),
- reinterpret_cast<Tensor *>(output), qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2,
- stride_s_or_t, stride_b, stride_h, stride_d, stream);
+ fused_rope_forward(*convertNVTETensorCheck(input), *convertNVTETensorCheck(cu_seqlens),
+ *convertNVTETensorCheck(freqs), *convertNVTETensorCheck(start_positions),
+ convertNVTETensorCheck(output), qkv_format, interleaved, cp_size, cp_rank, s,
+ b, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, stream);
}
void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
@@ -324,9 +323,8 @@ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu
cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_rope_backward);
using namespace transformer_engine;
- fused_rope_backward(*reinterpret_cast<const Tensor *>(output_grads),
- *reinterpret_cast<const Tensor *>(cu_seqlens),
- *reinterpret_cast<const Tensor *>(freqs),
- reinterpret_cast<Tensor *>(input_grads), qkv_format, interleaved, cp_size,
- cp_rank, s, b, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, stream);
+ fused_rope_backward(*convertNVTETensorCheck(output_grads), *convertNVTETensorCheck(cu_seqlens),
+ *convertNVTETensorCheck(freqs), convertNVTETensorCheck(input_grads),
+ qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t,
+ stride_b, stride_h, stride_d, stream);
}
@@ -551,8 +551,8 @@ void nvte_scaled_aligned_causal_masked_softmax_forward(const NVTETensor input,
float scale_factor, cudaStream_t stream) {
NVTE_API_CALL(nvte_scaled_aligned_causal_masked_softmax_forward);
using namespace transformer_engine;
- scaled_aligned_causal_masked_softmax_forward(*reinterpret_cast<const Tensor *>(input),
- reinterpret_cast<Tensor *>(softmax_results),
+ scaled_aligned_causal_masked_softmax_forward(*convertNVTETensorCheck(input),
+ convertNVTETensorCheck(softmax_results),
scale_factor, stream);
}
@@ -563,6 +563,6 @@ void nvte_scaled_aligned_causal_masked_softmax_backward(const NVTETensor incomin
NVTE_API_CALL(nvte_scaled_aligned_causal_masked_softmax_backward);
using namespace transformer_engine;
scaled_aligned_causal_masked_softmax_backward(
- *reinterpret_cast<Tensor *>(output_grads), *reinterpret_cast<const Tensor *>(incoming_grads),
- *reinterpret_cast<const Tensor *>(softmax_results), scale_factor, stream);
+ *convertNVTETensorCheck(output_grads), *convertNVTETensorCheck(incoming_grads),
+ *convertNVTETensorCheck(softmax_results), scale_factor, stream);
}
@@ -815,8 +815,8 @@ void nvte_scaled_softmax_forward(const NVTETensor input, NVTETensor softmax_resu
float scale_factor, cudaStream_t stream) {
NVTE_API_CALL(nvte_scaled_softmax_forward);
using namespace transformer_engine;
- scaled_softmax_forward(*reinterpret_cast<const Tensor *>(input),
- reinterpret_cast<Tensor *>(softmax_results), scale_factor, stream);
+ scaled_softmax_forward(*convertNVTETensorCheck(input), convertNVTETensorCheck(softmax_results),
+ scale_factor, stream);
}
void nvte_scaled_softmax_backward(const NVTETensor incoming_grads, const NVTETensor softmax_results,
@@ -824,9 +824,9 @@ void nvte_scaled_softmax_backward(const NVTETensor incoming_grads, const NVTETen
cudaStream_t stream) {
NVTE_API_CALL(nvte_scaled_softmax_backward);
using namespace transformer_engine;
- scaled_softmax_backward(*reinterpret_cast<Tensor *>(output_grads),
- *reinterpret_cast<const Tensor *>(incoming_grads),
- *reinterpret_cast<const Tensor *>(softmax_results), scale_factor, stream);
+ scaled_softmax_backward(*convertNVTETensorCheck(output_grads),
+ *convertNVTETensorCheck(incoming_grads),
+ *convertNVTETensorCheck(softmax_results), scale_factor, stream);
}
void nvte_scaled_masked_softmax_forward(const NVTETensor input, const NVTETensor mask,
@@ -834,9 +834,8 @@ void nvte_scaled_masked_softmax_forward(const NVTETensor input, const NVTETensor
cudaStream_t stream) {
NVTE_API_CALL(nvte_scaled_masked_softmax_forward);
using namespace transformer_engine;
- scaled_masked_softmax_forward(*reinterpret_cast<const Tensor *>(input),
- *reinterpret_cast<const Tensor *>(mask),
- reinterpret_cast<Tensor *>(softmax_results), scale_factor, stream);
+ scaled_masked_softmax_forward(*convertNVTETensorCheck(input), *convertNVTETensorCheck(mask),
+ convertNVTETensorCheck(softmax_results), scale_factor, stream);
}
void nvte_scaled_masked_softmax_backward(const NVTETensor incoming_grads,
@@ -844,7 +843,7 @@ void nvte_scaled_masked_softmax_backward(const NVTETensor incoming_grads,
float scale_factor, cudaStream_t stream) {
NVTE_API_CALL(nvte_scaled_masked_softmax_backward);
using namespace transformer_engine;
- scaled_masked_softmax_backward(
- *reinterpret_cast<Tensor *>(output_grads), *reinterpret_cast<const Tensor *>(incoming_grads),
- *reinterpret_cast<const Tensor *>(softmax_results), scale_factor, stream);
+ scaled_masked_softmax_backward(*convertNVTETensorCheck(output_grads),
+ *convertNVTETensorCheck(incoming_grads),
+ *convertNVTETensorCheck(softmax_results), scale_factor, stream);
}
@@ -599,9 +599,9 @@ void nvte_scaled_upper_triang_masked_softmax_forward(const NVTETensor input,
NVTETensor softmax_results, float scale_factor,
cudaStream_t stream) {
using namespace transformer_engine;
- scaled_upper_triang_masked_softmax_forward(*reinterpret_cast<const Tensor *>(input),
- reinterpret_cast<Tensor *>(softmax_results),
- scale_factor, stream);
+ scaled_upper_triang_masked_softmax_forward(*convertNVTETensorCheck(input),
+ convertNVTETensorCheck(softmax_results), scale_factor,
+ stream);
}
void nvte_scaled_upper_triang_masked_softmax_backward(const NVTETensor incoming_grads,
@@ -610,6 +610,6 @@ void nvte_scaled_upper_triang_masked_softmax_backward(const NVTETensor incoming_
cudaStream_t stream) {
using namespace transformer_engine;
scaled_upper_triang_masked_softmax_backward(
- *reinterpret_cast<Tensor *>(output_grads), *reinterpret_cast<const Tensor *>(incoming_grads),
- *reinterpret_cast<const Tensor *>(softmax_results), scale_factor, stream);
+ *convertNVTETensorCheck(output_grads), *convertNVTETensorCheck(incoming_grads),
+ *convertNVTETensorCheck(softmax_results), scale_factor, stream);
}
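The softmax wrappers are pure pass-throughs, so they show the converter change in its simplest form: every argument is required, hence convertNVTETensorCheck everywhere, and an invalid handle now fails with a readable error before any kernel launch. A minimal caller sketch (logits and probs are assumed to be valid, pre-built NVTETensor handles; the signature matches the declaration visible in the hunk above):

#include <cmath>

void softmax_example(NVTETensor logits, NVTETensor probs, cudaStream_t stream) {
  const float scale = 1.0f / std::sqrt(64.0f);  // e.g. 1/sqrt(head_dim)
  nvte_scaled_softmax_forward(logits, probs, scale, stream);
  // A null handle would now hit the convertNVTETensorCheck error path
  // instead of dereferencing a null Tensor inside the implementation.
}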
@@ -588,12 +588,12 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
int math_sm_count, cudaStream_t stream) {
NVTE_API_CALL(nvte_cublas_gemm);
using namespace transformer_engine;
- const Tensor *inputA = reinterpret_cast<const Tensor *>(A);
- const Tensor *inputB = reinterpret_cast<const Tensor *>(B);
- Tensor *outputD = reinterpret_cast<Tensor *>(D);
- const Tensor *biasTensor = reinterpret_cast<const Tensor *>(bias);
- Tensor *outputGelu = reinterpret_cast<Tensor *>(pre_gelu_out);
- Tensor *wspace = reinterpret_cast<Tensor *>(workspace);
+ const Tensor *inputA = convertNVTETensorCheck(A);
+ const Tensor *inputB = convertNVTETensorCheck(B);
+ Tensor *outputD = convertNVTETensor(D);
+ const Tensor *biasTensor = convertNVTETensor(bias);
+ Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
+ Tensor *wspace = convertNVTETensor(workspace);
cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
(transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0],
@@ -616,13 +616,13 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
"Cublas version >=12.2.5 and <13.0 is required for atomic gemm.");
using namespace transformer_engine;
- const Tensor *inputA = reinterpret_cast<const Tensor *>(A);
- const Tensor *inputB = reinterpret_cast<const Tensor *>(B);
- Tensor *outputD = reinterpret_cast<Tensor *>(D);
- const Tensor *biasTensor = reinterpret_cast<const Tensor *>(bias);
- Tensor *outputGelu = reinterpret_cast<Tensor *>(pre_gelu_out);
- const Tensor *inputCounter = reinterpret_cast<const Tensor *>(counter);
- Tensor *wspace = reinterpret_cast<Tensor *>(workspace);
+ const Tensor *inputA = convertNVTETensorCheck(A);
+ const Tensor *inputB = convertNVTETensorCheck(B);
+ Tensor *outputD = convertNVTETensor(D);
+ const Tensor *biasTensor = convertNVTETensor(bias);
+ Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
+ const Tensor *inputCounter = convertNVTETensor(counter);
+ Tensor *wspace = convertNVTETensor(workspace);
NVTE_CHECK(is_delayed_tensor_scaling(inputA->scaling_mode) &&
is_delayed_tensor_scaling(inputB->scaling_mode),
......
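The GEMM wrappers show the intended split between the two converters: A and B must be real tensors (convertNVTETensorCheck), while D, bias, pre_gelu_out, counter, and workspace may legitimately be empty and therefore use the unchecked convertNVTETensor. An illustrative sketch of how such an optional argument is then handled downstream (not the verbatim library code):

void handle_optional_bias(NVTETensor bias) {
  const Tensor *biasTensor = convertNVTETensor(bias);  // null passes through
  const bool with_bias =
      biasTensor != nullptr && biasTensor->data.dptr != nullptr;
  if (with_bias) {
    // set up the epilogue that fuses the bias addition
  }
}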
@@ -806,7 +806,7 @@ void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETenso
using namespace transformer_engine;
multi_tensor_adam::multi_tensor_adam_cuda(
- chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
+ chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2,
epsilon, step, mode, bias_correction, weight_decay, device_id, stream);
}
@@ -820,7 +820,7 @@ void nvte_multi_tensor_adam_param_remainder_cuda(
using namespace transformer_engine;
multi_tensor_adam::multi_tensor_adam_param_remainder_cuda(
- chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
+ chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2,
epsilon, step, mode, bias_correction, weight_decay, device_id, stream);
}
@@ -836,7 +836,7 @@ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
using namespace transformer_engine;
multi_tensor_adam::multi_tensor_adam_fp8_cuda(
- chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
+ chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2,
epsilon, step, mode, bias_correction, weight_decay, static_cast<DType>(fp8_dtype), device_id,
stream);
@@ -851,11 +851,10 @@ void nvte_multi_tensor_adam_capturable_cuda(
using namespace transformer_engine;
multi_tensor_adam::multi_tensor_adam_capturable_cuda(
- chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
+ chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
- *reinterpret_cast<Tensor *>(lr), beta1, beta2, epsilon, *reinterpret_cast<Tensor *>(step),
- mode, bias_correction, weight_decay, *reinterpret_cast<Tensor *>(inv_scale), device_id,
- stream);
+ *convertNVTETensorCheck(lr), beta1, beta2, epsilon, *convertNVTETensorCheck(step), mode,
+ bias_correction, weight_decay, *convertNVTETensorCheck(inv_scale), device_id, stream);
}
void nvte_multi_tensor_adam_capturable_master_cuda(
@@ -867,9 +866,8 @@ void nvte_multi_tensor_adam_capturable_master_cuda(
using namespace transformer_engine;
multi_tensor_adam::multi_tensor_adam_capturable_master_cuda(
- chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
+ chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
- *reinterpret_cast<Tensor *>(lr), beta1, beta2, epsilon, *reinterpret_cast<Tensor *>(step),
- mode, bias_correction, weight_decay, *reinterpret_cast<Tensor *>(inv_scale), device_id,
- stream);
+ *convertNVTETensorCheck(lr), beta1, beta2, epsilon, *convertNVTETensorCheck(step), mode,
+ bias_correction, weight_decay, *convertNVTETensorCheck(inv_scale), device_id, stream);
}
@@ -77,7 +77,7 @@ void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(
using namespace transformer_engine;
multi_tensor_compute_scale::multi_tensor_compute_scale_and_scale_inv_cuda(
- chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
+ chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), max_fp8,
force_pow_2_scales, epsilon, device_id, stream);
}
@@ -459,10 +459,10 @@ void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETen
using namespace transformer_engine;
multi_tensor_l2norm::multi_tensor_l2norm_cuda(
- chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
+ chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
- *reinterpret_cast<Tensor *>(output), *reinterpret_cast<Tensor *>(output_per_tensor),
- *reinterpret_cast<Tensor *>(ret), *reinterpret_cast<Tensor *>(ret_per_tensor), per_tensor,
+ *convertNVTETensorCheck(output), *convertNVTETensorCheck(output_per_tensor),
+ *convertNVTETensorCheck(ret), *convertNVTETensorCheck(ret_per_tensor), per_tensor,
max_chunks_per_tensor, device_id, stream);
}
@@ -477,9 +477,9 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
using namespace transformer_engine;
multi_tensor_l2norm::multi_tensor_unscale_l2norm_cuda(
- chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
+ chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
- *reinterpret_cast<Tensor *>(output), *reinterpret_cast<Tensor *>(output_per_tensor),
- *reinterpret_cast<Tensor *>(ret), *reinterpret_cast<Tensor *>(ret_per_tensor),
- *reinterpret_cast<Tensor *>(inv_scale), per_tensor, max_chunks_per_tensor, device_id, stream);
+ *convertNVTETensorCheck(output), *convertNVTETensorCheck(output_per_tensor),
+ *convertNVTETensorCheck(ret), *convertNVTETensorCheck(ret_per_tensor),
+ *convertNVTETensorCheck(inv_scale), per_tensor, max_chunks_per_tensor, device_id, stream);
}
@@ -124,7 +124,7 @@ void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETens
using namespace transformer_engine;
multi_tensor_scale::multi_tensor_scale_cuda(
- chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
+ chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), scale, device_id,
stream);
}
@@ -196,7 +196,7 @@ void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor
using namespace transformer_engine;
multi_tensor_sgd::multi_tensor_sgd_cuda(
- chunk_size, *reinterpret_cast<Tensor*>(noop_flag),
+ chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), wd, momentum,
dampening, lr, nesterov, first_run, wd_after_momentum, scale, device_id, stream);
}
@@ -105,11 +105,11 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
// Compute FP8 transpose if required
if (z->has_columnwise_data() && is_tensor_scaling(z->scaling_mode)) {
- Tensor transpose_data;
- transpose_data.data = z->columnwise_data;
- transpose_data.scaling_mode = z->scaling_mode;
- nvte_transpose(reinterpret_cast<NVTETensor>(z), reinterpret_cast<NVTETensor>(&transpose_data),
- stream);
+ NVTETensor transpose_data = nvte_create_tensor(z->scaling_mode);
+ Tensor& t = *convertNVTETensor(transpose_data);
+ t.data = z->columnwise_data;
+ nvte_transpose(static_cast<NVTETensor>(*z), transpose_data, stream);
+ nvte_destroy_tensor(transpose_data);
}
return;
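This layernorm hunk is the one place the diff exercises the new public lifecycle directly, and it also explains the "Fix the double free error" commit: the old code wrapped a stack-allocated Tensor in an NVTETensor handle, which cannot coexist with handles that are owned and recycled by the internal allocator. The general shape of the new pattern for a short-lived view tensor, a sketch based on the calls visible above:

// A temporary tensor that aliases existing device memory (illustrative).
NVTETensor tmp = nvte_create_tensor(z->scaling_mode);  // pooled handle
Tensor &t = *convertNVTETensor(tmp);
t.data = z->columnwise_data;        // alias, no copy of device memory
nvte_transpose(static_cast<NVTETensor>(*z), tmp, stream);
nvte_destroy_tensor(tmp);           // returns the Tensor object to the pool

Since TransformerEngine tensors are non-owning views, destroying the handle releases only the pooled Tensor object, never the aliased device buffer, so the transpose output owned by z stays valid.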
@@ -195,11 +195,10 @@ void nvte_layernorm_fwd(const NVTETensor x, // BxSxhidden_size
const bool zero_centered_gamma, cudaStream_t stream) {
NVTE_API_CALL(nvte_layernorm_fwd);
using namespace transformer_engine;
- layernorm_fwd(*reinterpret_cast<const Tensor*>(x), *reinterpret_cast<const Tensor*>(gamma),
- *reinterpret_cast<const Tensor*>(beta), epsilon, reinterpret_cast<Tensor*>(z),
- reinterpret_cast<Tensor*>(mu), reinterpret_cast<Tensor*>(rsigma),
- reinterpret_cast<Tensor*>(workspace), multiprocessorCount, zero_centered_gamma,
- stream);
+ layernorm_fwd(*convertNVTETensorCheck(x), *convertNVTETensorCheck(gamma),
+ *convertNVTETensorCheck(beta), epsilon, convertNVTETensor(z), convertNVTETensor(mu),
+ convertNVTETensor(rsigma), convertNVTETensor(workspace), multiprocessorCount,
+ zero_centered_gamma, stream);
}
void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size
@@ -212,10 +211,9 @@ void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size
cudaStream_t stream) {
NVTE_API_CALL(nvte_layernorm_bwd);
using namespace transformer_engine;
- layernorm_bwd(*reinterpret_cast<const Tensor*>(dz), *reinterpret_cast<const Tensor*>(x),
- *reinterpret_cast<const Tensor*>(mu), *reinterpret_cast<const Tensor*>(rsigma),
- *reinterpret_cast<const Tensor*>(gamma), reinterpret_cast<Tensor*>(dx),
- reinterpret_cast<Tensor*>(dgamma), reinterpret_cast<Tensor*>(dbeta),
- reinterpret_cast<Tensor*>(workspace), multiprocessorCount, zero_centered_gamma,
- stream);
+ layernorm_bwd(*convertNVTETensorCheck(dz), *convertNVTETensorCheck(x),
+ *convertNVTETensorCheck(mu), *convertNVTETensorCheck(rsigma),
+ *convertNVTETensorCheck(gamma), convertNVTETensor(dx), convertNVTETensor(dgamma),
+ convertNVTETensor(dbeta), convertNVTETensor(workspace), multiprocessorCount,
+ zero_centered_gamma, stream);
}
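Note that the normalization wrappers now pass their inputs through convertNVTETensorCheck but their outputs and workspace through the unchecked convertNVTETensor as plain pointers, which keeps the workspace-query convention working. A caller-side sketch, assuming the usual TransformerEngine workspace-query flow and with the tensor handles (x, gamma, beta, z, mu, rsigma) prepared by the caller:

NVTETensor workspace = nvte_create_tensor(NVTE_DELAYED_TENSOR_SCALING);
// Pass 1: with the workspace data pointer unset, only the size is computed.
nvte_layernorm_fwd(x, gamma, beta, epsilon, z, mu, rsigma, workspace,
                   multiprocessorCount, zero_centered_gamma, stream);
Tensor *w = convertNVTETensor(workspace);
// Allocate the bytes described by w->data.shape on device, set w->data.dptr,
// then repeat the call to actually run the kernels:
nvte_layernorm_fwd(x, gamma, beta, epsilon, z, mu, rsigma, workspace,
                   multiprocessorCount, zero_centered_gamma, stream);
nvte_destroy_tensor(workspace);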