Unverified Commit cd11e00d authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Add missing docs for C API (#1803)



* Add missing docs for C API
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Grammar, typos, copy-paste errors
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* remove contiguous word
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Better wording
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d35afe12
..
Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.
cast_transpose_noop.h
=====================
.. doxygenfile:: cast_transpose_noop.h
..
Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.
comm_gemm_overlap.h
===================
.. doxygenfile:: comm_gemm_overlap.h
..
Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.
cudnn.h
=======
.. doxygenfile:: cudnn.h
......@@ -14,10 +14,14 @@ directly from C/C++, without Python.
transformer_engine.h <transformer_engine>
activation.h <activation>
cast_transpose_noop.h <cast_transpose_noop>
cast.h <cast>
comm_gemm_overlap.h <comm_gemm_overlap>
cudnn.h <cudnn>
fused_attn.h <fused_attn>
fused_rope.h <fused_rope>
gemm.h <gemm>
multi_tensor.h <multi_tensor>
normalization.h <normalization>
padding.h <padding>
permutation.h <permutation>
......
..
Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE for license information.
multi_tensor.h
==============
.. doxygenfile:: multi_tensor.h
......@@ -17,23 +17,21 @@
extern "C" {
#endif
/*! \brief Transposes the input, providing the option to immediately exit the kernel
* based on the value of the 'noop' tensor.
/*! \brief Transposes the input.
*
* \param[in] input Input tensor.
* \param[in] noop Noop tensor.
* \param[in] input Input tensor to be transposed.
* \param[in] noop If this single element tensor has non-zero value, kernel will exit immediately.
* \param[in,out] output Output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output,
cudaStream_t stream);
/*! \brief Casts and transposes the input, providing the option to immediately exit the kernel
* based on the value of the 'noop' tensor.
/*! \brief Casts and transposes the input.
*
* \param[in] input Input tensor.
* \param[in] noop Noop tensor.
* \param[in,out] output Output tensor.
* \param[in] input Input tensor to be cast.
* \param[in] noop If this single element tensor has non-zero value, kernel will exit immediately.
* \param[in,out] output Output quantized tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_cast_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output,
......
......@@ -580,6 +580,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
cudaStream_t stream);
/*! \brief Update the RNG state with the seed and calculated offset.
*
* \warning This API is **experimental** and subject to change.
*
* \param[in] rng_state_dst RNG state to store seed and offset.
* \param[in] seed Seed for RNG state.
......@@ -595,6 +597,8 @@ void nvte_populate_rng_state_async(NVTETensor rng_state_dst, const NVTETensor se
NVTE_Fused_Attn_Backend backend, cudaStream_t stream);
/*! \brief Get KV format for a given QKV layout.
*
* \warning This API is **experimental** and subject to change.
*
* \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1].
* \param[in] workspace Workspace tensor.
......@@ -604,48 +608,187 @@ void nvte_populate_rng_state_async(NVTETensor rng_state_dst, const NVTETensor se
uint32_t nvte_get_runtime_num_segments(NVTETensor cu_seqlen, NVTETensor workspace, size_t len,
cudaStream_t stream);
/*! \brief Set the seed and offset for RNG state.
*
* \warning This API is **experimental** and subject to change.
*
* \param[out] rng_state_ptr A size 2 array storing the RNG's seed and offset respectively.
* \param[in] captured Whether a CUDA graph is being captured.
* \param[in] seed_ptr Seed pointer.
* \param[in] seed_val Seed value.
* \param[in] offset_ptr Offset pointer.
* \param[in] offset_val Offset value.
* \param[in] offset_intragraph Intragraph offset in RNG states. For use with CUDA Graphs.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_extract_seed_and_offset(int64_t *rng_state_ptr, int captured, int64_t *seed_ptr,
uint64_t seed_val, int64_t *offset_ptr, uint64_t offset_val,
uint32_t offset_intragraph, cudaStream_t stream);
/*! \brief Copy keys and values into the KV cache.
*
* \warning This API is **experimental** and subject to change.
*
* \param[in] new_k Key tensor.
* \param[in] new_v Value tensor.
* \param[out] k_cache Key cache.
* \param[out] v_cache Value cache.
* \param[in] page_table Page table for K cache, [batch_size, max_pages_per_seq].
* \param[in] cu_new_lens Cumulative sequence lengths.
* \param[in] cu_cached_lens Cached cumulative sequence lengths.
* \param[in] qkv_format QKV format, e.g. sbhd.
* \param[in] b Batch size.
* \param[in] max_ctx_len Maximum context length.
* \param[in] max_seq_len Maximum sequence length.
* \param[in] max_pages_per_seq Maximum number of pages per sequence.
* \param[in] is_non_paged Whether the cache is non-paged (as opposed to paged).
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_copy_to_kv_cache(NVTETensor new_k, NVTETensor new_v, NVTETensor k_cache,
NVTETensor v_cache, NVTETensor page_table, NVTETensor cu_new_lens,
NVTETensor cu_cached_lens, NVTE_QKV_Format qkv_format, int b,
int max_ctx_len, int max_seq_len, int max_pages_per_seq,
int is_non_paged, cudaStream_t stream);
/*! \brief Extract the first half (half_idx=0) or second half (half_idx=1) of a THD tensor.
*
* \warning This API is **experimental** and subject to change.
*
* \param[in] tensor Input tensor.
* \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1].
* \param[out] half Output tensor.
* \param[in] half_idx Whether to read first or second half of input tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_cp_thd_read_half_tensor(const NVTETensor &tensor, const NVTETensor &cu_seqlens,
NVTETensor half, int half_idx, cudaStream_t stream);
/*! \brief Correct the second half of the softmax LSE (LogSumExp) for context parallelism.
*
* \warning This API is **experimental** and subject to change.
*
* \param[out] lse Output tensor.
* \param[in] lse_per_step Input tensor.
* \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1].
* \param[in] lse_packed Whether or not lse_per_step is packed.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_cp_thd_second_half_lse_correction(NVTETensor lse, const NVTETensor &lse_per_step,
const NVTETensor &cu_seqlens, int lse_packed,
cudaStream_t stream);
/*! \brief Read the second half of the softmax LSE (LogSumExp) for context parallelism.
*
* \warning This API is **experimental** and subject to change.
*
* \param[in] lse Input tensor.
* \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1].
* \param[out] half_lse Output tensor.
* \param[in] lse_packed Whether or not the softmax LSE is in packed format.
* \param[in] second_half_lse_seqlen Sequence length.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_cp_thd_read_second_half_lse(const NVTETensor &lse, const NVTETensor &cu_seqlens,
NVTETensor half_lse, int lse_packed,
int second_half_lse_seqlen, cudaStream_t stream);
/*! \brief Correct the THD format output of context parallelism in forward pass.
*
* \warning This API is **experimental** and subject to change.
*
* \param[out] out Output tensor.
* \param[in] out_per_step THD format output of context parallelism in forward pass.
* \param[in] lse Softmax LSE.
* \param[in] lse_per_step Softmax LSE per step.
* \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1].
* \param[in] only_second_half Whether or not to correct only second half.
* \param[in] lse_packed Whether or not the softmax LSE is in packed format.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_cp_thd_out_correction(NVTETensor out, const NVTETensor &out_per_step,
const NVTETensor &lse, const NVTETensor &lse_per_step,
const NVTETensor &cu_seqlens, int only_second_half, int lse_packed,
cudaStream_t stream);
/*! \brief Correct the THD format gradient of context parallelism in backward pass.
*
* \warning This API is **experimental** and subject to change.
*
* \param[out] grad Output tensor.
* \param[in] grad_per_step THD format gradient of context parallelism.
* \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1].
* \param[in] first_half One of ("add", "copy", "none") correction op for first half.
* \param[in] second_half One of ("add", "copy", "none") correction op for second half.
*                           Must be different from first_half.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_cp_thd_grad_correction(NVTETensor grad, const NVTETensor &grad_per_step,
const NVTETensor &cu_seqlens, const char *first_half,
const char *second_half, cudaStream_t stream);
/*! \brief Generate partitioned indices for inputs in THD format.
*
* \warning This API is **experimental** and subject to change.
*
* \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1].
* \param[out] output Output tensor.
* \param[in] total_tokens Total number of tokens.
* \param[in] world_size Total number of devices for context parallelism.
* \param[in] rank Device ID for current device.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_cp_thd_get_partitioned_indices(const NVTETensor &cu_seqlens, NVTETensor output,
int total_tokens, int world_size, int rank,
cudaStream_t stream);
/*! \brief Convert tensor from THD to BSHD format.
*
* \warning This API is **experimental** and subject to change.
*
* \param[in] tensor Input tensor.
* \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1].
* \param[out] new_tensor Output tensor.
* \param[in] b Batch size.
* \param[in] max_seq_len Maximum sequence length.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_convert_thd_to_bshd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor,
int b, int max_seq_len, cudaStream_t stream);
/*! \brief Convert tensor from BSHD to THD format.
*
* \warning This API is **experimental** and subject to change.
*
* \param[in] tensor Input tensor.
* \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1].
* \param[out] new_tensor Output tensor.
* \param[in] t Total number of tokens.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_convert_bshd_to_thd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor,
int t, cudaStream_t stream);
/*! \brief Prepare QKV tensor for Flash Attention forward kernel.
*
* \warning This API is **experimental** and subject to change.
*
* \param[in] qkvi Input tensor.
* \param[out] qkv Output tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_prepare_flash_attn_fwd(NVTETensor qkvi, NVTETensor qkv, cudaStream_t stream);
/*! \brief Prepare QKV tensor for Flash Attention backward kernel.
*
* \warning This API is **experimental** and subject to change.
*
* \param[in] q Input query tensor.
* \param[in] k Input key tensor.
* \param[in] v Input value tensor.
* \param[out] qkv Output tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_prepare_flash_attn_bwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTETensor qkv,
cudaStream_t stream);
......
......@@ -17,6 +17,25 @@
extern "C" {
#endif
/*! \brief Computes L2 norm for a list of tensors.
*
* \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
* \param[in] tensor_lists 2D array of input tensors.
* \param[in] num_tensor_lists Size (dim0) of tensor_lists.
* \param[in] num_tensors_per_list Size (dim1) of tensor_lists.
* \param[in] output Scratch space. Required size grows with number of inputs.
* \param[in] output_per_tensor Fixed size auxiliary scratch space.
* \param[out] ret L2 norm of all inputs.
* \param[out] ret_per_tensor L2 norm for each tensor.
* \param[in] per_tensor Whether to calculate per tensor or cumulative norm.
* \param[in] max_chunks_per_tensor Maximum number of chunks in any input tensor.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
NVTETensor output, NVTETensor output_per_tensor, NVTETensor ret,
......@@ -24,6 +43,28 @@ void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETen
int max_chunks_per_tensor, const int device_id,
cudaStream_t stream);
/*! \brief Computes L2 norm for a list of tensors after unscaling.
*
* Unscaling is only done for computing the L2 norm. The tensors themselves are not updated.
*
* \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
* \param[in] tensor_lists 2D array of input tensors.
* \param[in] num_tensor_lists Size (dim0) of tensor_lists.
* \param[in] num_tensors_per_list Size (dim1) of tensor_lists.
* \param[in] output Scratch space. Required size grows with number of inputs.
* \param[in] output_per_tensor Fixed size auxiliary scratch space.
* \param[out] ret L2 norm of all inputs.
* \param[out] ret_per_tensor L2 norm for each tensor.
* \param[in] inv_scale Scalar for the unscaling operation.
* \param[in] per_tensor Whether to calculate per tensor or cumulative norm.
* \param[in] max_chunks_per_tensor Maximum number of chunks in any input tensor.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, NVTETensor output,
......@@ -32,6 +73,27 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
int per_tensor, int max_chunks_per_tensor,
const int device_id, cudaStream_t stream);
/*! \brief Compute and apply gradient update to parameters for Adam optimizer.
*
* \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
* \param[in,out] tensor_lists 2D array of input tensors.
* \param[in] num_tensor_lists Size (dim0) of tensor_lists.
* \param[in] num_tensors_per_list Size (dim1) of tensor_lists.
* \param[in] lr Learning rate.
* \param[in] beta1 Coefficient for first moment of gradient.
* \param[in] beta2 Coefficient for second moment of gradient.
* \param[in] epsilon Term added to the denominator for numerical stability.
* \param[in] step Iteration counter.
* \param[in] mode Whether to use AdamW (L2 penalty applied to params).
* \param[in] bias_correction Whether to apply correction factor for moment estimates.
* \param[in] weight_decay L2 penalty for weight decay.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
const float lr, const float beta1, const float beta2,
......@@ -39,12 +101,57 @@ void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETenso
const int bias_correction, const float weight_decay,
const int device_id, cudaStream_t stream);
/*! \brief Compute and apply gradient update to parameters for Adam optimizer
* where the master parameters only store the remainder bits.
*
* \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
* \param[in,out] tensor_lists 2D array of input tensors.
* \param[in] num_tensor_lists Size (dim0) of tensor_lists.
* \param[in] num_tensors_per_list Size (dim1) of tensor_lists.
* \param[in] lr Learning rate.
* \param[in] beta1 Coefficient for first moment of gradient.
* \param[in] beta2 Coefficient for second moment of gradient.
* \param[in] epsilon Term added to the denominator for numerical stability.
* \param[in] step Iteration counter.
* \param[in] mode Whether to use AdamW (L2 penalty applied to params).
* \param[in] bias_correction Whether to apply correction factor for moment estimates.
* \param[in] weight_decay L2 penalty for weight decay.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_adam_param_remainder_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode, const int bias_correction,
const float weight_decay, const int device_id, cudaStream_t stream);
/*! \brief Compute and apply gradient update to parameters for Adam optimizer
* when model parameters are in Float8 precision.
*
* \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
* \param[in,out] tensor_lists 2D array of input tensors.
* \param[in] num_tensor_lists Size (dim0) of tensor_lists.
* \param[in] num_tensors_per_list Size (dim1) of tensor_lists.
* \param[in] lr Learning rate.
* \param[in] beta1 Coefficient for first moment of gradient.
* \param[in] beta2 Coefficient for second moment of gradient.
* \param[in] epsilon Term added to the denominator for numerical stability.
* \param[in] step Iteration counter.
* \param[in] mode Whether to use AdamW (L2 penalty applied to params).
* \param[in] bias_correction Whether to apply correction factor for moment estimates.
* \param[in] weight_decay L2 penalty for weight decay.
* \param[in] fp8_dtype FP8 data type for model parameters.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, const float lr,
......@@ -53,28 +160,125 @@ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
const float weight_decay, const NVTEDType fp8_dtype,
const int device_id, cudaStream_t stream);
/*! \brief Compute and apply gradient update to parameters for Adam optimizer
* with CUDA graph support and LR scheduling.
*
* \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
* \param[in,out] tensor_lists 2D array of input tensors.
* \param[in] num_tensor_lists Size (dim0) of tensor_lists.
* \param[in] num_tensors_per_list Size (dim1) of tensor_lists.
* \param[in] lr Learning rate.
* \param[in] beta1 Coefficient for first moment of gradient.
* \param[in] beta2 Coefficient for second moment of gradient.
* \param[in] epsilon Term added to the denominator for numerical stability.
* \param[in] step Iteration counter.
* \param[in] mode Whether to use AdamW (L2 penalty applied to params).
* \param[in] bias_correction Whether to apply correction factor for moment estimates.
* \param[in] weight_decay L2 penalty for weight decay.
* \param[in] inv_scale Scalar for the unscaling operation.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_adam_capturable_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2,
const float epsilon, NVTETensor step, const int mode, const int bias_correction,
const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream);
/*! \brief Compute and apply gradient update to parameters for Adam optimizer
* with CUDA graph support, LR scheduling, and FP32 master weights.
*
* \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
* \param[in,out] tensor_lists 2D array of input tensors.
* \param[in] num_tensor_lists Size (dim0) of tensor_lists.
* \param[in] num_tensors_per_list Size (dim1) of tensor_lists.
* \param[in] lr Learning rate.
* \param[in] beta1 Coefficient for first moment of gradient.
* \param[in] beta2 Coefficient for second moment of gradient.
* \param[in] epsilon Term added to the denominator for numerical stability.
* \param[in] step Iteration counter.
* \param[in] mode Whether to use AdamW (L2 penalty applied to params).
* \param[in] bias_correction Whether to apply correction factor for moment estimates.
* \param[in] weight_decay L2 penalty for weight decay.
* \param[in] inv_scale Scalar for the unscaling operation.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_adam_capturable_master_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2,
const float epsilon, NVTETensor step, const int mode, const int bias_correction,
const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream);
/*! \brief Compute and apply gradient update to parameters for SGD optimizer.
*
* \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
* \param[in,out] tensor_lists 2D array of input tensors.
* \param[in] num_tensor_lists Size (dim0) of tensor_lists.
* \param[in] num_tensors_per_list Size (dim1) of tensor_lists.
* \param[in] wd Weight decay (L2 penalty).
* \param[in] momentum Momentum factor.
* \param[in] dampening Dampening factor.
* \param[in] lr Learning rate.
* \param[in] nesterov Whether or not to enable nesterov momentum.
* \param[in] first_run Whether momentum buffers have been initialized.
* \param[in] wd_after_momentum Whether to apply weight decay after the momentum update.
* \param[in] scale Scalar for the scaling operation.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
float wd, float momentum, float dampening, float lr, int nesterov,
int first_run, int wd_after_momentum, float scale,
const int device_id, cudaStream_t stream);
/*! \brief Check overflow and scale a list of tensors.
*
* \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
* \param[in,out] tensor_lists 2D array of input tensors.
* \param[in] num_tensor_lists Size (dim0) of tensor_lists.
* \param[in] num_tensors_per_list Size (dim1) of tensor_lists.
* \param[in] scale Scalar for the scaling operation.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
float scale, const int device_id, cudaStream_t stream);
/*! \brief Compute scaling factors and their inverses for a list of tensors.
*
* \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
* \param[in,out] tensor_lists 2D array of input tensors.
* \param[in] num_tensor_lists Size (dim0) of tensor_lists.
* \param[in] num_tensors_per_list Size (dim1) of tensor_lists.
* \param[in] max_fp8 Maximum representable value in underlying FP8 format.
* \param[in] force_pow_2_scales Ensure scaling factors are a power of 2.
* \param[in] epsilon Term added to the denominator for numerical stability.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, float max_fp8, int force_pow_2_scales, float epsilon,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment