"googlemock/vscode:/vscode.git/clone" did not exist on "e4717df71a4f45bf9f0ac88c6cd9846a0bc248dd"
Unverified commit 1470116e, authored by Kirthi Shankar Sivamani and committed by GitHub
Browse files

[C][PyTorch] Remove deprecated `device_id` arg for multi tensor API (#1994)



* Remove deprecated device arg
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Remove test
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 374849e3
...@@ -112,13 +112,6 @@ class TestFusedAdam(TestFusedOptimizer): ...@@ -112,13 +112,6 @@ class TestFusedAdam(TestFusedOptimizer):
def test_bfloat16(self): def test_bfloat16(self):
self.gen_single_type_test(param_type=torch.bfloat16, skip_assert=True) self.gen_single_type_test(param_type=torch.bfloat16, skip_assert=True)
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="more than 1 GPU required")
def test_multi_device(self):
devices = ("cuda:0", "cuda:1")
for current_dev, tensor_dev in product(devices, devices):
with torch.cuda.device(current_dev):
self.gen_single_type_test(param_type=torch.float, device=tensor_dev)
def test_multi_params(self): def test_multi_params(self):
sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]] sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
...@@ -530,13 +523,6 @@ class TestFusedSGD(TestFusedOptimizer): ...@@ -530,13 +523,6 @@ class TestFusedSGD(TestFusedOptimizer):
def test_half(self): def test_half(self):
self.gen_single_type_test(param_type=torch.float16) self.gen_single_type_test(param_type=torch.float16)
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="more than 1 GPU required")
def test_multi_device(self):
devices = ("cuda:0", "cuda:1")
for current_dev, tensor_dev in product(devices, devices):
with torch.cuda.device(current_dev):
self.gen_single_type_test(param_type=torch.float, device=tensor_dev)
class Model(torch.nn.Module): class Model(torch.nn.Module):
def __init__(self): def __init__(self):
......
...@@ -20,7 +20,6 @@ extern "C" { ...@@ -20,7 +20,6 @@ extern "C" {
/*! \brief Computes L2 norm for a list of tensors. /*! \brief Computes L2 norm for a list of tensors.
* *
* \warning This API is **experimental** and subject to change. * \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
* *
* \param[in] chunk_size Number of tensor elements processed by a CUDA block. * \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
...@@ -33,22 +32,19 @@ extern "C" { ...@@ -33,22 +32,19 @@ extern "C" {
* \param[out] ret_per_tensor L2 norm for each tensor. * \param[out] ret_per_tensor L2 norm for each tensor.
* \param[in] per_tensor Whether to calculate per tensor or cumulative norm. * \param[in] per_tensor Whether to calculate per tensor or cumulative norm.
* \param[in] max_chunks_per_tensor Maximum number of chunks in any input tensor. * \param[in] max_chunks_per_tensor Maximum number of chunks in any input tensor.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation. * \param[in] stream CUDA stream used for this operation.
*/ */
void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list, const size_t num_tensor_lists, const size_t num_tensors_per_list,
NVTETensor output, NVTETensor output_per_tensor, NVTETensor ret, NVTETensor output, NVTETensor output_per_tensor, NVTETensor ret,
NVTETensor ret_per_tensor, int per_tensor, NVTETensor ret_per_tensor, int per_tensor,
int max_chunks_per_tensor, const int device_id, int max_chunks_per_tensor, cudaStream_t stream);
cudaStream_t stream);
/*! \brief Computes L2 norm for a list of tensors after unscaling. /*! \brief Computes L2 norm for a list of tensors after unscaling.
* *
* Unscaling is only done for computing the L2 norm. The tensors themselves are not updated. * Unscaling is only done for computing the L2 norm. The tensors themselves are not updated.
* *
* \warning This API is **experimental** and subject to change. * \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
* *
* \param[in] chunk_size Number of tensor elements processed by a CUDA block. * \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
...@@ -62,7 +58,6 @@ void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETen ...@@ -62,7 +58,6 @@ void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETen
* \param[in] inv_scale Scalar for the unscaling operation. * \param[in] inv_scale Scalar for the unscaling operation.
* \param[in] per_tensor Whether to calculate per tensor or cumulative norm. * \param[in] per_tensor Whether to calculate per tensor or cumulative norm.
* \param[in] max_chunks_per_tensor Maximum number of chunks in any input tensor. * \param[in] max_chunks_per_tensor Maximum number of chunks in any input tensor.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation. * \param[in] stream CUDA stream used for this operation.
*/ */
void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag, void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
...@@ -71,12 +66,11 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag, ...@@ -71,12 +66,11 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
NVTETensor output_per_tensor, NVTETensor ret, NVTETensor output_per_tensor, NVTETensor ret,
NVTETensor ret_per_tensor, NVTETensor inv_scale, NVTETensor ret_per_tensor, NVTETensor inv_scale,
int per_tensor, int max_chunks_per_tensor, int per_tensor, int max_chunks_per_tensor,
const int device_id, cudaStream_t stream); cudaStream_t stream);
/*! \brief Compute and apply gradient update to parameters for Adam optimizer. /*! \brief Compute and apply gradient update to parameters for Adam optimizer.
* *
* \warning This API is **experimental** and subject to change. * \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
* *
* \param[in] chunk_size Number of tensor elements processed by a CUDA block. * \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
...@@ -91,7 +85,6 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag, ...@@ -91,7 +85,6 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
* \param[in] mode Whether to use AdamW (L2 penalty applied to params). * \param[in] mode Whether to use AdamW (L2 penalty applied to params).
* \param[in] bias_correction Whether to apply correction factor for moment estimates. * \param[in] bias_correction Whether to apply correction factor for moment estimates.
* \param[in] weight_decay L2 penalty for weight decay. * \param[in] weight_decay L2 penalty for weight decay.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation. * \param[in] stream CUDA stream used for this operation.
*/ */
void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
...@@ -99,13 +92,12 @@ void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETenso ...@@ -99,13 +92,12 @@ void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETenso
const float lr, const float beta1, const float beta2, const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode, const float epsilon, const int step, const int mode,
const int bias_correction, const float weight_decay, const int bias_correction, const float weight_decay,
const int device_id, cudaStream_t stream); cudaStream_t stream);
/*! \brief Compute and apply gradient update to parameters for Adam optimizer /*! \brief Compute and apply gradient update to parameters for Adam optimizer
* where the master parameters only store the remainder bits. * where the master parameters only store the remainder bits.
* *
* \warning This API is **experimental** and subject to change. * \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
* *
* \param[in] chunk_size Number of tensor elements processed by a CUDA block. * \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
...@@ -120,20 +112,18 @@ void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETenso ...@@ -120,20 +112,18 @@ void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETenso
* \param[in] mode Whether to use AdamW (L2 penalty applied to params). * \param[in] mode Whether to use AdamW (L2 penalty applied to params).
* \param[in] bias_correction Whether to apply correction factor for moment estimates. * \param[in] bias_correction Whether to apply correction factor for moment estimates.
* \param[in] weight_decay L2 penalty for weight decay. * \param[in] weight_decay L2 penalty for weight decay.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation. * \param[in] stream CUDA stream used for this operation.
*/ */
void nvte_multi_tensor_adam_param_remainder_cuda( void nvte_multi_tensor_adam_param_remainder_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, const float lr, const float beta1, const float beta2, const size_t num_tensors_per_list, const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode, const int bias_correction, const float epsilon, const int step, const int mode, const int bias_correction,
const float weight_decay, const int device_id, cudaStream_t stream); const float weight_decay, cudaStream_t stream);
/*! \brief Compute and apply gradient update to parameters for Adam optimizer /*! \brief Compute and apply gradient update to parameters for Adam optimizer
* when model parameters are in Float8 precision. * when model parameters are in Float8 precision.
* *
* \warning This API is **experimental** and subject to change. * \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
* *
* \param[in] chunk_size Number of tensor elements processed by a CUDA block. * \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
...@@ -149,7 +139,6 @@ void nvte_multi_tensor_adam_param_remainder_cuda( ...@@ -149,7 +139,6 @@ void nvte_multi_tensor_adam_param_remainder_cuda(
* \param[in] bias_correction Whether to apply correction factor for moment estimates. * \param[in] bias_correction Whether to apply correction factor for moment estimates.
* \param[in] weight_decay L2 penalty for weight decay. * \param[in] weight_decay L2 penalty for weight decay.
* \param[in] fp8_dtype FP8 data type for model parameters. * \param[in] fp8_dtype FP8 data type for model parameters.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation. * \param[in] stream CUDA stream used for this operation.
*/ */
void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag, void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
...@@ -158,13 +147,12 @@ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag, ...@@ -158,13 +147,12 @@ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
const float beta1, const float beta2, const float epsilon, const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction, const int step, const int mode, const int bias_correction,
const float weight_decay, const NVTEDType fp8_dtype, const float weight_decay, const NVTEDType fp8_dtype,
const int device_id, cudaStream_t stream); cudaStream_t stream);
/*! \brief Compute and apply gradient update to parameters for Adam optimizer /*! \brief Compute and apply gradient update to parameters for Adam optimizer
* with CUDA graph support and LR scheduling. * with CUDA graph support and LR scheduling.
* *
* \warning This API is **experimental** and subject to change. * \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
* *
* \param[in] chunk_size Number of tensor elements processed by a CUDA block. * \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
...@@ -180,20 +168,18 @@ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag, ...@@ -180,20 +168,18 @@ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
* \param[in] bias_correction Whether to apply correction factor for moment estimates. * \param[in] bias_correction Whether to apply correction factor for moment estimates.
* \param[in] weight_decay L2 penalty for weight decay. * \param[in] weight_decay L2 penalty for weight decay.
* \param[in] inv_scale Scalar for the unscaling operation. * \param[in] inv_scale Scalar for the unscaling operation.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation. * \param[in] stream CUDA stream used for this operation.
*/ */
void nvte_multi_tensor_adam_capturable_cuda( void nvte_multi_tensor_adam_capturable_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2, const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2,
const float epsilon, NVTETensor step, const int mode, const int bias_correction, const float epsilon, NVTETensor step, const int mode, const int bias_correction,
const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream); const float weight_decay, NVTETensor inv_scale, cudaStream_t stream);
/*! \brief Compute and apply gradient update to parameters for Adam optimizer /*! \brief Compute and apply gradient update to parameters for Adam optimizer
* with CUDA graph support, LR scheduling, and FP32 master weights. * with CUDA graph support, LR scheduling, and FP32 master weights.
* *
* \warning This API is **experimental** and subject to change. * \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
* *
* \param[in] chunk_size Number of tensor elements processed by a CUDA block. * \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
...@@ -209,19 +195,17 @@ void nvte_multi_tensor_adam_capturable_cuda( ...@@ -209,19 +195,17 @@ void nvte_multi_tensor_adam_capturable_cuda(
* \param[in] bias_correction Whether to apply correction factor for moment estimates. * \param[in] bias_correction Whether to apply correction factor for moment estimates.
* \param[in] weight_decay L2 penalty for weight decay. * \param[in] weight_decay L2 penalty for weight decay.
* \param[in] inv_scale Scalar for the unscaling operation. * \param[in] inv_scale Scalar for the unscaling operation.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation. * \param[in] stream CUDA stream used for this operation.
*/ */
void nvte_multi_tensor_adam_capturable_master_cuda( void nvte_multi_tensor_adam_capturable_master_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2, const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2,
const float epsilon, NVTETensor step, const int mode, const int bias_correction, const float epsilon, NVTETensor step, const int mode, const int bias_correction,
const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream); const float weight_decay, NVTETensor inv_scale, cudaStream_t stream);
/*! \brief Compute and apply gradient update to parameters for SGD optimizer. /*! \brief Compute and apply gradient update to parameters for SGD optimizer.
* *
* \warning This API is **experimental** and subject to change. * \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
* *
* \param[in] chunk_size Number of tensor elements processed by a CUDA block. * \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
...@@ -236,19 +220,17 @@ void nvte_multi_tensor_adam_capturable_master_cuda( ...@@ -236,19 +220,17 @@ void nvte_multi_tensor_adam_capturable_master_cuda(
* \param[in] first_run Whether momentum buffers have been initialized. * \param[in] first_run Whether momentum buffers have been initialized.
* \param[in] wd_after_momentum Whether to apply weight decay after momentum update. * \param[in] wd_after_momentum Whether to apply weight decay after momentum update.
* \param[in] scale Scalar for the scaling operation. * \param[in] scale Scalar for the scaling operation.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation. * \param[in] stream CUDA stream used for this operation.
*/ */
void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list, const size_t num_tensor_lists, const size_t num_tensors_per_list,
float wd, float momentum, float dampening, float lr, int nesterov, float wd, float momentum, float dampening, float lr, int nesterov,
int first_run, int wd_after_momentum, float scale, int first_run, int wd_after_momentum, float scale,
const int device_id, cudaStream_t stream); cudaStream_t stream);
/*! \brief Check overflow and scale a list of tensors. /*! \brief Check overflow and scale a list of tensors.
* *
* \warning This API is **experimental** and subject to change. * \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
* *
* \param[in] chunk_size Number of tensor elements processed by a CUDA block. * \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
...@@ -256,17 +238,15 @@ void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor ...@@ -256,17 +238,15 @@ void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor
* \param[in] num_tensor_lists Size (dim0) of tensor_lists. * \param[in] num_tensor_lists Size (dim0) of tensor_lists.
* \param[in] num_tensors_per_list Size (dim1) of tensor_lists. * \param[in] num_tensors_per_list Size (dim1) of tensor_lists.
* \param[in] scale Scalar for the scaling operation. * \param[in] scale Scalar for the scaling operation.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation. * \param[in] stream CUDA stream used for this operation.
*/ */
void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list, const size_t num_tensor_lists, const size_t num_tensors_per_list,
float scale, const int device_id, cudaStream_t stream); float scale, cudaStream_t stream);
/*! \brief Check overflow and scale a list of tensors. /*! \brief Check overflow and scale a list of tensors.
* *
* \warning This API is **experimental** and subject to change. * \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
* *
* \param[in] chunk_size Number of tensor elements processed by a CUDA block. * \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
...@@ -276,13 +256,14 @@ void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETens ...@@ -276,13 +256,14 @@ void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETens
* \param[in] max_fp8 Maximum representable value in underlying FP8 format. * \param[in] max_fp8 Maximum representable value in underlying FP8 format.
* \param[in] force_pow_2_scales Ensure scaling factors are a power of 2. * \param[in] force_pow_2_scales Ensure scaling factors are a power of 2.
* \param[in] epsilon Term added to the denominator for numerical stability. * \param[in] epsilon Term added to the denominator for numerical stability.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation. * \param[in] stream CUDA stream used for this operation.
*/ */
void nvte_multi_tensor_compute_scale_and_scale_inv_cuda( void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, NVTETensor noop_flag,
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, NVTETensor **tensor_lists,
const size_t num_tensors_per_list, float max_fp8, int force_pow_2_scales, float epsilon, const size_t num_tensor_lists,
const int device_id, cudaStream_t stream); const size_t num_tensors_per_list,
float max_fp8, int force_pow_2_scales,
float epsilon, cudaStream_t stream);
#ifdef __cplusplus #ifdef __cplusplus
} // extern "C" } // extern "C"
......
...@@ -576,7 +576,7 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag, ...@@ -576,7 +576,7 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor *>> tensor_lists, const float lr, std::vector<std::vector<Tensor *>> tensor_lists, const float lr,
const float beta1, const float beta2, const float epsilon, const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction, const int step, const int mode, const int bias_correction,
const float weight_decay, const int device_id, cudaStream_t stream) { const float weight_decay, cudaStream_t stream) {
// Handle bias correction mode // Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f; float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
if (bias_correction == 1) { if (bias_correction == 1) {
...@@ -643,20 +643,20 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag, ...@@ -643,20 +643,20 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
g_in_type_te, g_in_type, g_in_type_te, g_in_type,
multi_tensor_apply<4>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, multi_tensor_apply<4>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag,
tensor_lists, tensor_lists,
AdamFunctor<p_in_type, g_in_type, float, int64_t>(), device_id, AdamFunctor<p_in_type, g_in_type, float, int64_t>(), stream,
stream, beta1, beta2, bias_correction1, bias_correction2, beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
epsilon, lr, (adamMode_t)mode, weight_decay);)); (adamMode_t)mode, weight_decay);));
} else { } else {
// g, p, m, v, p_master // g, p, m, v, p_master
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
p_in_type_te, p_in_type, p_in_type_te, p_in_type,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
g_in_type_te, g_in_type, g_in_type_te, g_in_type,
multi_tensor_apply<5>( multi_tensor_apply<5>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag,
(int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists, tensor_lists,
AdamFunctorMaster<p_in_type, g_in_type, float, int64_t>(), device_id, stream, AdamFunctorMaster<p_in_type, g_in_type, float, int64_t>(),
beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, stream, beta1, beta2, bias_correction1, bias_correction2,
weight_decay);)); epsilon, lr, (adamMode_t)mode, weight_decay);));
} }
} else { } else {
if (num_tensor_lists == 4) { if (num_tensor_lists == 4) {
...@@ -666,9 +666,9 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag, ...@@ -666,9 +666,9 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
g_in_type_te, g_in_type, g_in_type_te, g_in_type,
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
AdamFunctor<p_in_type, g_in_type, float, int32_t>(), device_id, AdamFunctor<p_in_type, g_in_type, float, int32_t>(), stream,
stream, beta1, beta2, bias_correction1, bias_correction2, beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
epsilon, lr, (adamMode_t)mode, weight_decay);)); (adamMode_t)mode, weight_decay);));
} else { } else {
// g, p, m, v, p_master // g, p, m, v, p_master
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
...@@ -677,9 +677,8 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag, ...@@ -677,9 +677,8 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
g_in_type_te, g_in_type, g_in_type_te, g_in_type,
multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
AdamFunctorMaster<p_in_type, g_in_type, float, int32_t>(), AdamFunctorMaster<p_in_type, g_in_type, float, int32_t>(),
device_id, stream, beta1, beta2, bias_correction1, stream, beta1, beta2, bias_correction1, bias_correction2,
bias_correction2, epsilon, lr, (adamMode_t)mode, epsilon, lr, (adamMode_t)mode, weight_decay);));
weight_decay);));
} }
} }
NVTE_CHECK_CUDA(cudaGetLastError()); NVTE_CHECK_CUDA(cudaGetLastError());
...@@ -690,7 +689,7 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, Tensor noop_flag, ...@@ -690,7 +689,7 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, Tensor noop_flag,
const float lr, const float beta1, const float beta2, const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode, const float epsilon, const int step, const int mode,
const int bias_correction, const float weight_decay, const int bias_correction, const float weight_decay,
const int device_id, cudaStream_t stream) { cudaStream_t stream) {
// Handle bias correction mode // Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f; float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
if (bias_correction == 1) { if (bias_correction == 1) {
...@@ -732,8 +731,8 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, Tensor noop_flag, ...@@ -732,8 +731,8 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, Tensor noop_flag,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
g_in_type_te, g_in_type, g_in_type_te, g_in_type,
multi_tensor_apply<5>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists, multi_tensor_apply<5>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,
AdamFunctorMasterParamRemainder<g_in_type, float, int64_t>(), device_id, AdamFunctorMasterParamRemainder<g_in_type, float, int64_t>(), stream,
stream, beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay);); (adamMode_t)mode, weight_decay););
NVTE_CHECK_CUDA(cudaGetLastError()); NVTE_CHECK_CUDA(cudaGetLastError());
} }
...@@ -743,7 +742,7 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag, ...@@ -743,7 +742,7 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag,
const float beta1, const float beta2, const float epsilon, const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction, const int step, const int mode, const int bias_correction,
const float weight_decay, const DType fp8_dtype, const float weight_decay, const DType fp8_dtype,
const int device_id, cudaStream_t stream) { cudaStream_t stream) {
// Handle bias correction mode // Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f; float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
if (bias_correction == 1) { if (bias_correction == 1) {
...@@ -813,9 +812,8 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag, ...@@ -813,9 +812,8 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag,
g_in_type_te, g_in_type, g_in_type_te, g_in_type,
multi_tensor_apply<5, true>( multi_tensor_apply<5, true>(
(int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists, (int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,
AdamFunctorMaster<FP8_T, g_in_type, float, int64_t>(), device_id, stream, beta1, AdamFunctorMaster<FP8_T, g_in_type, float, int64_t>(), stream, beta1, beta2,
beta2, bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, weight_decay);));
weight_decay);));
} else { } else {
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
fp8_dtype, FP8_T, fp8_dtype, FP8_T,
...@@ -823,9 +821,8 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag, ...@@ -823,9 +821,8 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag,
g_in_type_te, g_in_type, g_in_type_te, g_in_type,
multi_tensor_apply<5, true>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, multi_tensor_apply<5, true>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
AdamFunctorMaster<FP8_T, g_in_type, float, int32_t>(), AdamFunctorMaster<FP8_T, g_in_type, float, int32_t>(),
device_id, stream, beta1, beta2, bias_correction1, stream, beta1, beta2, bias_correction1, bias_correction2,
bias_correction2, epsilon, lr, (adamMode_t)mode, epsilon, lr, (adamMode_t)mode, weight_decay);));
weight_decay);));
} }
NVTE_CHECK_CUDA(cudaGetLastError()); NVTE_CHECK_CUDA(cudaGetLastError());
} }
...@@ -835,7 +832,7 @@ void multi_tensor_adam_capturable_cuda(int chunk_size, Tensor noop_flag, ...@@ -835,7 +832,7 @@ void multi_tensor_adam_capturable_cuda(int chunk_size, Tensor noop_flag,
const float beta1, const float beta2, const float epsilon, const float beta1, const float beta2, const float epsilon,
Tensor step, const int mode, const int bias_correction, Tensor step, const int mode, const int bias_correction,
const float weight_decay, Tensor inv_scale, const float weight_decay, Tensor inv_scale,
const int device_id, cudaStream_t stream) { cudaStream_t stream) {
// Check tensor list sizes // Check tensor list sizes
// 4 tensor lists: g, p, m, v // 4 tensor lists: g, p, m, v
const size_t num_tensor_lists = tensor_lists.size(); const size_t num_tensor_lists = tensor_lists.size();
...@@ -867,7 +864,7 @@ void multi_tensor_adam_capturable_cuda(int chunk_size, Tensor noop_flag, ...@@ -867,7 +864,7 @@ void multi_tensor_adam_capturable_cuda(int chunk_size, Tensor noop_flag,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
tensor_lists[0][0]->dtype(), dtype, tensor_lists[0][0]->dtype(), dtype,
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
AdamCapturableFunctor<dtype, float>(), device_id, stream, beta1, beta2, AdamCapturableFunctor<dtype, float>(), stream, beta1, beta2,
reinterpret_cast<int *>(step.data.dptr), bias_correction, epsilon, reinterpret_cast<int *>(step.data.dptr), bias_correction, epsilon,
reinterpret_cast<float *>(lr.data.dptr), (adamMode_t)mode, weight_decay, reinterpret_cast<float *>(lr.data.dptr), (adamMode_t)mode, weight_decay,
reinterpret_cast<float *>(inv_scale.data.dptr));) reinterpret_cast<float *>(inv_scale.data.dptr));)
...@@ -880,8 +877,7 @@ void multi_tensor_adam_capturable_master_cuda(int chunk_size, Tensor noop_flag, ...@@ -880,8 +877,7 @@ void multi_tensor_adam_capturable_master_cuda(int chunk_size, Tensor noop_flag,
Tensor lr, const float beta1, const float beta2, Tensor lr, const float beta1, const float beta2,
const float epsilon, Tensor step, const int mode, const float epsilon, Tensor step, const int mode,
const int bias_correction, const float weight_decay, const int bias_correction, const float weight_decay,
Tensor inv_scale, const int device_id, Tensor inv_scale, cudaStream_t stream) {
cudaStream_t stream) {
// Check tensor list sizes // Check tensor list sizes
// 4 tensor lists: g, p, m, v, p_master // 4 tensor lists: g, p, m, v, p_master
const size_t num_tensor_lists = tensor_lists.size(); const size_t num_tensor_lists = tensor_lists.size();
...@@ -916,10 +912,10 @@ void multi_tensor_adam_capturable_master_cuda(int chunk_size, Tensor noop_flag, ...@@ -916,10 +912,10 @@ void multi_tensor_adam_capturable_master_cuda(int chunk_size, Tensor noop_flag,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
tensor_lists[0][0]->dtype(), dtype, tensor_lists[0][0]->dtype(), dtype,
multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
AdamCapturableMasterFunctor<dtype, float>(), device_id, stream, beta1, AdamCapturableMasterFunctor<dtype, float>(), stream, beta1, beta2,
beta2, reinterpret_cast<int *>(step.data.dptr), bias_correction, reinterpret_cast<int *>(step.data.dptr), bias_correction, epsilon,
epsilon, reinterpret_cast<float *>(lr.data.dptr), (adamMode_t)mode, reinterpret_cast<float *>(lr.data.dptr), (adamMode_t)mode, weight_decay,
weight_decay, reinterpret_cast<float *>(inv_scale.data.dptr));) reinterpret_cast<float *>(inv_scale.data.dptr));)
NVTE_CHECK_CUDA(cudaGetLastError()); NVTE_CHECK_CUDA(cudaGetLastError());
} }
...@@ -932,28 +928,28 @@ void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETenso ...@@ -932,28 +928,28 @@ void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETenso
const float lr, const float beta1, const float beta2, const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode, const float epsilon, const int step, const int mode,
const int bias_correction, const float weight_decay, const int bias_correction, const float weight_decay,
const int device_id, cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_adam_cuda); NVTE_API_CALL(nvte_multi_tensor_adam_cuda);
using namespace transformer_engine; using namespace transformer_engine;
multi_tensor_adam::multi_tensor_adam_cuda( multi_tensor_adam::multi_tensor_adam_cuda(
chunk_size, *convertNVTETensorCheck(noop_flag), chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2, convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2,
epsilon, step, mode, bias_correction, weight_decay, device_id, stream); epsilon, step, mode, bias_correction, weight_decay, stream);
} }
void nvte_multi_tensor_adam_param_remainder_cuda( void nvte_multi_tensor_adam_param_remainder_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, const float lr, const float beta1, const float beta2, const size_t num_tensors_per_list, const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode, const int bias_correction, const float epsilon, const int step, const int mode, const int bias_correction,
const float weight_decay, const int device_id, cudaStream_t stream) { const float weight_decay, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_adam_param_remainder_cuda); NVTE_API_CALL(nvte_multi_tensor_adam_param_remainder_cuda);
using namespace transformer_engine; using namespace transformer_engine;
multi_tensor_adam::multi_tensor_adam_param_remainder_cuda( multi_tensor_adam::multi_tensor_adam_param_remainder_cuda(
chunk_size, *convertNVTETensorCheck(noop_flag), chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2, convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2,
epsilon, step, mode, bias_correction, weight_decay, device_id, stream); epsilon, step, mode, bias_correction, weight_decay, stream);
} }
void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag, void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
...@@ -962,22 +958,21 @@ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag, ...@@ -962,22 +958,21 @@ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
const float beta1, const float beta2, const float epsilon, const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction, const int step, const int mode, const int bias_correction,
const float weight_decay, const NVTEDType fp8_dtype, const float weight_decay, const NVTEDType fp8_dtype,
const int device_id, cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_adam_fp8_cuda); NVTE_API_CALL(nvte_multi_tensor_adam_fp8_cuda);
using namespace transformer_engine; using namespace transformer_engine;
multi_tensor_adam::multi_tensor_adam_fp8_cuda( multi_tensor_adam::multi_tensor_adam_fp8_cuda(
chunk_size, *convertNVTETensorCheck(noop_flag), chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2, convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2,
epsilon, step, mode, bias_correction, weight_decay, static_cast<DType>(fp8_dtype), device_id, epsilon, step, mode, bias_correction, weight_decay, static_cast<DType>(fp8_dtype), stream);
stream);
} }
void nvte_multi_tensor_adam_capturable_cuda( void nvte_multi_tensor_adam_capturable_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2, const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2,
const float epsilon, NVTETensor step, const int mode, const int bias_correction, const float epsilon, NVTETensor step, const int mode, const int bias_correction,
const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream) { const float weight_decay, NVTETensor inv_scale, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_adam_capturable_cuda); NVTE_API_CALL(nvte_multi_tensor_adam_capturable_cuda);
using namespace transformer_engine; using namespace transformer_engine;
...@@ -985,14 +980,14 @@ void nvte_multi_tensor_adam_capturable_cuda( ...@@ -985,14 +980,14 @@ void nvte_multi_tensor_adam_capturable_cuda(
chunk_size, *convertNVTETensorCheck(noop_flag), chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
*convertNVTETensorCheck(lr), beta1, beta2, epsilon, *convertNVTETensorCheck(step), mode, *convertNVTETensorCheck(lr), beta1, beta2, epsilon, *convertNVTETensorCheck(step), mode,
bias_correction, weight_decay, *convertNVTETensorCheck(inv_scale), device_id, stream); bias_correction, weight_decay, *convertNVTETensorCheck(inv_scale), stream);
} }
void nvte_multi_tensor_adam_capturable_master_cuda( void nvte_multi_tensor_adam_capturable_master_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2, const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2,
const float epsilon, NVTETensor step, const int mode, const int bias_correction, const float epsilon, NVTETensor step, const int mode, const int bias_correction,
const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream) { const float weight_decay, NVTETensor inv_scale, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_adam_capturable_master_cuda); NVTE_API_CALL(nvte_multi_tensor_adam_capturable_master_cuda);
using namespace transformer_engine; using namespace transformer_engine;
...@@ -1000,5 +995,5 @@ void nvte_multi_tensor_adam_capturable_master_cuda( ...@@ -1000,5 +995,5 @@ void nvte_multi_tensor_adam_capturable_master_cuda(
chunk_size, *convertNVTETensorCheck(noop_flag), chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
*convertNVTETensorCheck(lr), beta1, beta2, epsilon, *convertNVTETensorCheck(step), mode, *convertNVTETensorCheck(lr), beta1, beta2, epsilon, *convertNVTETensorCheck(step), mode,
bias_correction, weight_decay, *convertNVTETensorCheck(inv_scale), device_id, stream); bias_correction, weight_decay, *convertNVTETensorCheck(inv_scale), stream);
} }
...@@ -58,26 +58,27 @@ struct ComputeScaleAndScaleInvFunctor { ...@@ -58,26 +58,27 @@ struct ComputeScaleAndScaleInvFunctor {
void multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, Tensor noop_flag, void multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor *>> tensor_lists, std::vector<std::vector<Tensor *>> tensor_lists,
float max_fp8, bool force_pow_2_scales, float max_fp8, bool force_pow_2_scales,
float epsilon, const int device_id, float epsilon, cudaStream_t stream) {
cudaStream_t stream) {
multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
ComputeScaleAndScaleInvFunctor(), device_id, stream, max_fp8, ComputeScaleAndScaleInvFunctor(), stream, max_fp8, force_pow_2_scales,
force_pow_2_scales, epsilon); epsilon);
NVTE_CHECK_CUDA(cudaGetLastError()); NVTE_CHECK_CUDA(cudaGetLastError());
} }
} // namespace multi_tensor_compute_scale } // namespace multi_tensor_compute_scale
} // namespace transformer_engine } // namespace transformer_engine
void nvte_multi_tensor_compute_scale_and_scale_inv_cuda( void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, NVTETensor noop_flag,
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, NVTETensor **tensor_lists,
const size_t num_tensors_per_list, float max_fp8, int force_pow_2_scales, float epsilon, const size_t num_tensor_lists,
const int device_id, cudaStream_t stream) { const size_t num_tensors_per_list,
float max_fp8, int force_pow_2_scales,
float epsilon, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_compute_scale_and_scale_inv_cuda); NVTE_API_CALL(nvte_multi_tensor_compute_scale_and_scale_inv_cuda);
using namespace transformer_engine; using namespace transformer_engine;
multi_tensor_compute_scale::multi_tensor_compute_scale_and_scale_inv_cuda( multi_tensor_compute_scale::multi_tensor_compute_scale_and_scale_inv_cuda(
chunk_size, *convertNVTETensorCheck(noop_flag), chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), max_fp8, convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), max_fp8,
force_pow_2_scales, epsilon, device_id, stream); force_pow_2_scales, epsilon, stream);
} }
...@@ -393,13 +393,12 @@ __global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret, ...@@ -393,13 +393,12 @@ __global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret,
void multi_tensor_l2norm_cuda(int chunk_size, Tensor noop_flag, void multi_tensor_l2norm_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor *>> tensor_lists, Tensor output, std::vector<std::vector<Tensor *>> tensor_lists, Tensor output,
Tensor output_per_tensor, Tensor ret, Tensor ret_per_tensor, Tensor output_per_tensor, Tensor ret, Tensor ret_per_tensor,
bool per_tensor, int max_chunks_per_tensor, const int device_id, bool per_tensor, int max_chunks_per_tensor, cudaStream_t stream) {
cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
tensor_lists[0][0]->dtype(), dtype, tensor_lists[0][0]->dtype(), dtype,
multi_tensor_apply<1>( multi_tensor_apply<1>(
BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, L2NormFunctor<dtype>(), device_id, BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, L2NormFunctor<dtype>(), stream,
stream, reinterpret_cast<float *>(output.data.dptr), reinterpret_cast<float *>(output.data.dptr),
per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr, per_tensor, per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr, per_tensor,
max_chunks_per_tensor);) max_chunks_per_tensor);)
...@@ -408,7 +407,6 @@ void multi_tensor_l2norm_cuda(int chunk_size, Tensor noop_flag, ...@@ -408,7 +407,6 @@ void multi_tensor_l2norm_cuda(int chunk_size, Tensor noop_flag,
// This involves one more small kernel launches, but will be negligible end to end. // This involves one more small kernel launches, but will be negligible end to end.
// I could get rid of these by hacking the functor + multi tensor harness with persistence // I could get rid of these by hacking the functor + multi tensor harness with persistence
// logic, but keeping it simple for now // logic, but keeping it simple for now
const OptionalCUDAGuard device_guard(device_id);
cleanup<<<per_tensor ? tensor_lists[0].size() : 1, 512, 0, stream>>>( cleanup<<<per_tensor ? tensor_lists[0].size() : 1, 512, 0, stream>>>(
reinterpret_cast<float *>(output.data.dptr), reinterpret_cast<float *>(output.data.dptr),
per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr, per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr,
...@@ -421,13 +419,12 @@ void multi_tensor_unscale_l2norm_cuda(int chunk_size, Tensor noop_flag, ...@@ -421,13 +419,12 @@ void multi_tensor_unscale_l2norm_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor *>> tensor_lists, std::vector<std::vector<Tensor *>> tensor_lists,
Tensor output, Tensor output_per_tensor, Tensor ret, Tensor output, Tensor output_per_tensor, Tensor ret,
Tensor ret_per_tensor, Tensor inv_scale, bool per_tensor, Tensor ret_per_tensor, Tensor inv_scale, bool per_tensor,
int max_chunks_per_tensor, const int device_id, int max_chunks_per_tensor, cudaStream_t stream) {
cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
tensor_lists[0][0]->dtype(), dtype, tensor_lists[0][0]->dtype(), dtype,
multi_tensor_apply<1>( multi_tensor_apply<1>(
BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, UnscaleL2NormFunctor<dtype>(), device_id, BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, UnscaleL2NormFunctor<dtype>(), stream,
stream, reinterpret_cast<float *>(inv_scale.data.dptr), reinterpret_cast<float *>(inv_scale.data.dptr),
reinterpret_cast<float *>(output.data.dptr), reinterpret_cast<float *>(output.data.dptr),
per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr, per_tensor, per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr, per_tensor,
max_chunks_per_tensor);) max_chunks_per_tensor);)
...@@ -437,7 +434,6 @@ void multi_tensor_unscale_l2norm_cuda(int chunk_size, Tensor noop_flag, ...@@ -437,7 +434,6 @@ void multi_tensor_unscale_l2norm_cuda(int chunk_size, Tensor noop_flag,
// This involves one more small kernel launches, but will be negligible end to end. // This involves one more small kernel launches, but will be negligible end to end.
// I could get rid of these by hacking the functor + multi tensor harness with persistence // I could get rid of these by hacking the functor + multi tensor harness with persistence
// logic, but keeping it simple for now // logic, but keeping it simple for now
const OptionalCUDAGuard device_guard(device_id);
cleanup<<<per_tensor ? tensor_lists[0].size() : 1, 512, 0, stream>>>( cleanup<<<per_tensor ? tensor_lists[0].size() : 1, 512, 0, stream>>>(
reinterpret_cast<float *>(output.data.dptr), reinterpret_cast<float *>(output.data.dptr),
per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr, per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr,
...@@ -453,8 +449,7 @@ void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETen ...@@ -453,8 +449,7 @@ void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETen
const size_t num_tensor_lists, const size_t num_tensors_per_list, const size_t num_tensor_lists, const size_t num_tensors_per_list,
NVTETensor output, NVTETensor output_per_tensor, NVTETensor ret, NVTETensor output, NVTETensor output_per_tensor, NVTETensor ret,
NVTETensor ret_per_tensor, int per_tensor, NVTETensor ret_per_tensor, int per_tensor,
int max_chunks_per_tensor, const int device_id, int max_chunks_per_tensor, cudaStream_t stream) {
cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_l2norm_cuda); NVTE_API_CALL(nvte_multi_tensor_l2norm_cuda);
using namespace transformer_engine; using namespace transformer_engine;
...@@ -463,7 +458,7 @@ void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETen ...@@ -463,7 +458,7 @@ void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETen
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
*convertNVTETensorCheck(output), *convertNVTETensorCheck(output_per_tensor), *convertNVTETensorCheck(output), *convertNVTETensorCheck(output_per_tensor),
*convertNVTETensorCheck(ret), *convertNVTETensorCheck(ret_per_tensor), per_tensor, *convertNVTETensorCheck(ret), *convertNVTETensorCheck(ret_per_tensor), per_tensor,
max_chunks_per_tensor, device_id, stream); max_chunks_per_tensor, stream);
} }
void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag, void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
...@@ -472,7 +467,7 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag, ...@@ -472,7 +467,7 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
NVTETensor output_per_tensor, NVTETensor ret, NVTETensor output_per_tensor, NVTETensor ret,
NVTETensor ret_per_tensor, NVTETensor inv_scale, NVTETensor ret_per_tensor, NVTETensor inv_scale,
int per_tensor, int max_chunks_per_tensor, int per_tensor, int max_chunks_per_tensor,
const int device_id, cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_unscale_l2norm_cuda); NVTE_API_CALL(nvte_multi_tensor_unscale_l2norm_cuda);
using namespace transformer_engine; using namespace transformer_engine;
...@@ -481,5 +476,5 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag, ...@@ -481,5 +476,5 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
*convertNVTETensorCheck(output), *convertNVTETensorCheck(output_per_tensor), *convertNVTETensorCheck(output), *convertNVTETensorCheck(output_per_tensor),
*convertNVTETensorCheck(ret), *convertNVTETensorCheck(ret_per_tensor), *convertNVTETensorCheck(ret), *convertNVTETensorCheck(ret_per_tensor),
*convertNVTETensorCheck(inv_scale), per_tensor, max_chunks_per_tensor, device_id, stream); *convertNVTETensorCheck(inv_scale), per_tensor, max_chunks_per_tensor, stream);
} }
...@@ -14,53 +14,6 @@ ...@@ -14,53 +14,6 @@
// This header is the one-stop shop for all your multi-tensor apply needs. // This header is the one-stop shop for all your multi-tensor apply needs.
// Scoped guard that optionally switches the current CUDA device.
//
// Construction:
//   * new_device < 0            -> the guard does nothing.
//   * new_device == current one -> the guard does nothing.
//   * otherwise                 -> cudaSetDevice(new_device) is called and the
//                                  device that was current is remembered.
// Destruction: if the guard switched devices, the remembered device is made
// current again.
//
// The type is move-only. A moved-from guard is disarmed (it no longer restores
// anything), so the remembered device is restored by at most one object.
class OptionalCUDAGuard {
 public:
  // Switches to `new_device` if it is non-negative and differs from the
  // device reported by cudaGetDevice. Failures of the CUDA calls are reported
  // through NVTE_CHECK_CUDA.
  explicit OptionalCUDAGuard(int new_device) {
    if (new_device < 0) return;

    int current_device;
    NVTE_CHECK_CUDA(cudaGetDevice(&current_device));
    if (new_device != current_device) {
      NVTE_CHECK_CUDA(cudaSetDevice(new_device));
      device_changed_ = true;
      prev_device_ = current_device;
    }
  }

  OptionalCUDAGuard(const OptionalCUDAGuard &) = delete;
  OptionalCUDAGuard &operator=(const OptionalCUDAGuard &) = delete;

  // Takes over the pending restore (if any) from `other` and disarms it.
  OptionalCUDAGuard(OptionalCUDAGuard &&other) noexcept
      : prev_device_(other.prev_device_), device_changed_(other.device_changed_) {
    other.device_changed_ = false;
  }

  // Performs this guard's own pending restore first, then takes over the
  // pending restore (if any) from `other` and disarms it.
  OptionalCUDAGuard &operator=(OptionalCUDAGuard &&other) noexcept {
    if (this != &other) {
      if (device_changed_) {
        // Return code deliberately ignored, as in the destructor.
        cudaSetDevice(prev_device_);
      }
      prev_device_ = other.prev_device_;
      device_changed_ = other.device_changed_;
      other.device_changed_ = false;
    }
    return *this;
  }

  // Restores the remembered device if this guard switched devices.
  // The cudaSetDevice return code is not checked here.
  // NOTE(review): presumably because NVTE_CHECK_CUDA could throw and this
  // runs in a destructor — confirm against the macro's definition.
  ~OptionalCUDAGuard() {
    if (device_changed_) {
      cudaSetDevice(prev_device_);
    }
  }

 private:
  // Device to restore; only meaningful while device_changed_ is true.
  // Initialized so that the move operations, which copy this member
  // unconditionally, never read an indeterminate value from a guard that
  // did not switch devices.
  int prev_device_ = -1;
  // True when this guard switched devices and still owes a restore.
  bool device_changed_ = false;
};
// TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson) // TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson)
constexpr int depth_to_max_tensors[6] = {110, 64, 48, 36, 30, 24}; constexpr int depth_to_max_tensors[6] = {110, 64, 48, 36, 30, 24};
constexpr int depth_to_max_blocks[6] = {320, 320, 320, 320, 320, 320}; constexpr int depth_to_max_blocks[6] = {320, 320, 320, 320, 320, 320};
...@@ -94,7 +47,7 @@ template <int depth, bool USE_FP8 = false, typename T, typename... ArgTypes> ...@@ -94,7 +47,7 @@ template <int depth, bool USE_FP8 = false, typename T, typename... ArgTypes>
void multi_tensor_apply(int64_t block_size, int64_t chunk_size, void multi_tensor_apply(int64_t block_size, int64_t chunk_size,
const transformer_engine::Tensor &noop_flag, const transformer_engine::Tensor &noop_flag,
std::vector<std::vector<transformer_engine::Tensor *>> tensor_lists, std::vector<std::vector<transformer_engine::Tensor *>> tensor_lists,
T callable, const int device_id, cudaStream_t stream, ArgTypes... args) { T callable, cudaStream_t stream, ArgTypes... args) {
const size_t num_tensor_lists = tensor_lists.size(); const size_t num_tensor_lists = tensor_lists.size();
const size_t num_tensors_per_list = tensor_lists[0].size(); const size_t num_tensors_per_list = tensor_lists[0].size();
...@@ -108,8 +61,6 @@ void multi_tensor_apply(int64_t block_size, int64_t chunk_size, ...@@ -108,8 +61,6 @@ void multi_tensor_apply(int64_t block_size, int64_t chunk_size,
TensorListMetadata<depth, USE_FP8> tl; TensorListMetadata<depth, USE_FP8> tl;
const OptionalCUDAGuard device_guard(device_id);
tl.start_tensor_this_launch = 0; tl.start_tensor_this_launch = 0;
int loc_block_info = 0; int loc_block_info = 0;
int loc_tensor_info = 0; int loc_tensor_info = 0;
......
...@@ -104,13 +104,13 @@ struct ScaleFunctor { ...@@ -104,13 +104,13 @@ struct ScaleFunctor {
void multi_tensor_scale_cuda(int chunk_size, Tensor noop_flag, void multi_tensor_scale_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor *>> tensor_lists, float scale, std::vector<std::vector<Tensor *>> tensor_lists, float scale,
const int device_id, cudaStream_t stream) { cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
tensor_lists[0][0]->dtype(), p_in_type, tensor_lists[0][0]->dtype(), p_in_type,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
tensor_lists[1][0]->dtype(), g_in_type, tensor_lists[1][0]->dtype(), g_in_type,
multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
ScaleFunctor<p_in_type, g_in_type>(), device_id, stream, scale);)) ScaleFunctor<p_in_type, g_in_type>(), stream, scale);))
NVTE_CHECK_CUDA(cudaGetLastError()); NVTE_CHECK_CUDA(cudaGetLastError());
} }
...@@ -119,12 +119,11 @@ void multi_tensor_scale_cuda(int chunk_size, Tensor noop_flag, ...@@ -119,12 +119,11 @@ void multi_tensor_scale_cuda(int chunk_size, Tensor noop_flag,
void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list, const size_t num_tensor_lists, const size_t num_tensors_per_list,
float scale, const int device_id, cudaStream_t stream) { float scale, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_scale_cuda); NVTE_API_CALL(nvte_multi_tensor_scale_cuda);
using namespace transformer_engine; using namespace transformer_engine;
multi_tensor_scale::multi_tensor_scale_cuda( multi_tensor_scale::multi_tensor_scale_cuda(
chunk_size, *convertNVTETensorCheck(noop_flag), chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), scale, device_id, convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), scale, stream);
stream);
} }
...@@ -127,8 +127,7 @@ struct SGDFunctor { ...@@ -127,8 +127,7 @@ struct SGDFunctor {
void multi_tensor_sgd_cuda(int chunk_size, Tensor noop_flag, void multi_tensor_sgd_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor*>> tensor_lists, float wd, float momentum, std::vector<std::vector<Tensor*>> tensor_lists, float wd, float momentum,
float dampening, float lr, bool nesterov, bool first_run, float dampening, float lr, bool nesterov, bool first_run,
bool wd_after_momentum, float scale, const int device_id, bool wd_after_momentum, float scale, cudaStream_t stream) {
cudaStream_t stream) {
const size_t num_tensor_lists = tensor_lists.size(); const size_t num_tensor_lists = tensor_lists.size();
const size_t num_tensors_per_list = tensor_lists[0].size(); const size_t num_tensors_per_list = tensor_lists[0].size();
...@@ -154,29 +153,29 @@ void multi_tensor_sgd_cuda(int chunk_size, Tensor noop_flag, ...@@ -154,29 +153,29 @@ void multi_tensor_sgd_cuda(int chunk_size, Tensor noop_flag,
// Case 1. fp16, fp16, fp16, No // Case 1. fp16, fp16, fp16, No
if (grad_type == DType::kFloat16 && weight_type == DType::kFloat16 && num_tensor_lists == 3) { if (grad_type == DType::kFloat16 && weight_type == DType::kFloat16 && num_tensor_lists == 3) {
multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
SGDFunctor<3, fp16, fp16>(), device_id, stream, wd, momentum, dampening, SGDFunctor<3, fp16, fp16>(), stream, wd, momentum, dampening, lr,
lr, nesterov, first_run, wd_after_momentum, scale); nesterov, first_run, wd_after_momentum, scale);
} }
// Case 2. fp32, fp32, fp32, No // Case 2. fp32, fp32, fp32, No
else if (grad_type == DType::kFloat32 && // NOLINT(*) else if (grad_type == DType::kFloat32 && // NOLINT(*)
weight_type == DType::kFloat32 && num_tensor_lists == 3) { weight_type == DType::kFloat32 && num_tensor_lists == 3) {
multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
SGDFunctor<3, float, float>(), device_id, stream, wd, momentum, dampening, SGDFunctor<3, float, float>(), stream, wd, momentum, dampening, lr,
lr, nesterov, first_run, wd_after_momentum, scale); nesterov, first_run, wd_after_momentum, scale);
} }
// Case 3. fp16, fp32, fp32, Yes // Case 3. fp16, fp32, fp32, Yes
else if (grad_type == DType::kFloat16 && // NOLINT(*) else if (grad_type == DType::kFloat16 && // NOLINT(*)
weight_type == DType::kFloat32 && num_tensor_lists == 4) { weight_type == DType::kFloat32 && num_tensor_lists == 4) {
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
SGDFunctor<4, fp16, float>(), device_id, stream, wd, momentum, dampening, SGDFunctor<4, fp16, float>(), stream, wd, momentum, dampening, lr,
lr, nesterov, first_run, wd_after_momentum, scale); nesterov, first_run, wd_after_momentum, scale);
} }
// Case 4. fp32, fp32, fp32, Yes // Case 4. fp32, fp32, fp32, Yes
else if (grad_type == DType::kFloat32 && // NOLINT(*) else if (grad_type == DType::kFloat32 && // NOLINT(*)
weight_type == DType::kFloat32 && num_tensor_lists == 4) { weight_type == DType::kFloat32 && num_tensor_lists == 4) {
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
SGDFunctor<4, float, float>(), device_id, stream, wd, momentum, dampening, SGDFunctor<4, float, float>(), stream, wd, momentum, dampening, lr,
lr, nesterov, first_run, wd_after_momentum, scale); nesterov, first_run, wd_after_momentum, scale);
} else { } else {
NVTE_ERROR("Unsupported combination of weight and gradient types."); NVTE_ERROR("Unsupported combination of weight and gradient types.");
} }
...@@ -191,12 +190,12 @@ void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor ...@@ -191,12 +190,12 @@ void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor
const size_t num_tensor_lists, const size_t num_tensors_per_list, const size_t num_tensor_lists, const size_t num_tensors_per_list,
float wd, float momentum, float dampening, float lr, int nesterov, float wd, float momentum, float dampening, float lr, int nesterov,
int first_run, int wd_after_momentum, float scale, int first_run, int wd_after_momentum, float scale,
const int device_id, cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_sgd_cuda); NVTE_API_CALL(nvte_multi_tensor_sgd_cuda);
using namespace transformer_engine; using namespace transformer_engine;
multi_tensor_sgd::multi_tensor_sgd_cuda( multi_tensor_sgd::multi_tensor_sgd_cuda(
chunk_size, *convertNVTETensorCheck(noop_flag), chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), wd, momentum, convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), wd, momentum,
dampening, lr, nesterov, first_run, wd_after_momentum, scale, device_id, stream); dampening, lr, nesterov, first_run, wd_after_momentum, scale, stream);
} }
...@@ -16,11 +16,10 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, ...@@ -16,11 +16,10 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag); auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists); makeTransformerEngineTensorList(tensor_lists);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_adam_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, nvte_multi_tensor_adam_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists,
num_tensors, lr, beta1, beta2, epsilon, step, mode, bias_correction, num_tensors, lr, beta1, beta2, epsilon, step, mode, bias_correction,
weight_decay, device_id, at::cuda::getCurrentCUDAStream()); weight_decay, at::cuda::getCurrentCUDAStream());
} }
void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag, void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag,
...@@ -31,12 +30,10 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag ...@@ -31,12 +30,10 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag); auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists); makeTransformerEngineTensorList(tensor_lists);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_adam_param_remainder_cuda( nvte_multi_tensor_adam_param_remainder_cuda(
chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, lr, beta1, chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, lr, beta1,
beta2, epsilon, step, mode, bias_correction, weight_decay, device_id, beta2, epsilon, step, mode, bias_correction, weight_decay, at::cuda::getCurrentCUDAStream());
at::cuda::getCurrentCUDAStream());
} }
void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag, void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
...@@ -47,12 +44,11 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag, ...@@ -47,12 +44,11 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag); auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists); makeTransformerEngineTensorList(tensor_lists);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_adam_fp8_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), nvte_multi_tensor_adam_fp8_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(),
num_lists, num_tensors, lr, beta1, beta2, epsilon, step, mode, num_lists, num_tensors, lr, beta1, beta2, epsilon, step, mode,
bias_correction, weight_decay, static_cast<NVTEDType>(fp8_dtype), bias_correction, weight_decay, static_cast<NVTEDType>(fp8_dtype),
device_id, at::cuda::getCurrentCUDAStream()); at::cuda::getCurrentCUDAStream());
} }
void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag, void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag,
...@@ -67,12 +63,11 @@ void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag, ...@@ -67,12 +63,11 @@ void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag,
auto lr_cu = makeTransformerEngineTensor(lr); auto lr_cu = makeTransformerEngineTensor(lr);
auto step_cu = makeTransformerEngineTensor(step); auto step_cu = makeTransformerEngineTensor(step);
auto inv_scale_cu = makeTransformerEngineTensor(inv_scale); auto inv_scale_cu = makeTransformerEngineTensor(inv_scale);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_adam_capturable_cuda( nvte_multi_tensor_adam_capturable_cuda(
chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors,
lr_cu.data(), beta1, beta2, epsilon, step_cu.data(), mode, bias_correction, weight_decay, lr_cu.data(), beta1, beta2, epsilon, step_cu.data(), mode, bias_correction, weight_decay,
inv_scale_cu.data(), device_id, at::cuda::getCurrentCUDAStream()); inv_scale_cu.data(), at::cuda::getCurrentCUDAStream());
} }
void multi_tensor_adam_capturable_master_cuda(int chunk_size, at::Tensor noop_flag, void multi_tensor_adam_capturable_master_cuda(int chunk_size, at::Tensor noop_flag,
...@@ -87,12 +82,11 @@ void multi_tensor_adam_capturable_master_cuda(int chunk_size, at::Tensor noop_fl ...@@ -87,12 +82,11 @@ void multi_tensor_adam_capturable_master_cuda(int chunk_size, at::Tensor noop_fl
auto lr_cu = makeTransformerEngineTensor(lr); auto lr_cu = makeTransformerEngineTensor(lr);
auto step_cu = makeTransformerEngineTensor(step); auto step_cu = makeTransformerEngineTensor(step);
auto inv_scale_cu = makeTransformerEngineTensor(inv_scale); auto inv_scale_cu = makeTransformerEngineTensor(inv_scale);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_adam_capturable_master_cuda( nvte_multi_tensor_adam_capturable_master_cuda(
chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors,
lr_cu.data(), beta1, beta2, epsilon, step_cu.data(), mode, bias_correction, weight_decay, lr_cu.data(), beta1, beta2, epsilon, step_cu.data(), mode, bias_correction, weight_decay,
inv_scale_cu.data(), device_id, at::cuda::getCurrentCUDAStream()); inv_scale_cu.data(), at::cuda::getCurrentCUDAStream());
} }
} // namespace transformer_engine::pytorch } // namespace transformer_engine::pytorch
...@@ -14,11 +14,10 @@ void multi_tensor_compute_scale_and_scale_inv_cuda( ...@@ -14,11 +14,10 @@ void multi_tensor_compute_scale_and_scale_inv_cuda(
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag); auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists); makeTransformerEngineTensorList(tensor_lists);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_compute_scale_and_scale_inv_cuda( nvte_multi_tensor_compute_scale_and_scale_inv_cuda(
chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, max_fp8, chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, max_fp8,
force_pow_2_scales, epsilon, device_id, at::cuda::getCurrentCUDAStream()); force_pow_2_scales, epsilon, at::cuda::getCurrentCUDAStream());
} }
} // namespace transformer_engine::pytorch } // namespace transformer_engine::pytorch
...@@ -43,12 +43,11 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda( ...@@ -43,12 +43,11 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
auto output_per_tensor_cu = makeTransformerEngineTensor(output_per_tensor); auto output_per_tensor_cu = makeTransformerEngineTensor(output_per_tensor);
auto ret_cu = makeTransformerEngineTensor(ret); auto ret_cu = makeTransformerEngineTensor(ret);
auto ret_per_tensor_cu = makeTransformerEngineTensor(ret_per_tensor); auto ret_per_tensor_cu = makeTransformerEngineTensor(ret_per_tensor);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_l2norm_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, nvte_multi_tensor_l2norm_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists,
num_tensors, output_cu.data(), output_per_tensor_cu.data(), num_tensors, output_cu.data(), output_per_tensor_cu.data(),
ret_cu.data(), ret_per_tensor_cu.data(), per_tensor, ret_cu.data(), ret_per_tensor_cu.data(), per_tensor,
max_chunks_per_tensor, device_id, at::cuda::getCurrentCUDAStream()); max_chunks_per_tensor, at::cuda::getCurrentCUDAStream());
return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor); return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
} }
...@@ -91,13 +90,11 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda( ...@@ -91,13 +90,11 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(
auto ret_cu = makeTransformerEngineTensor(ret); auto ret_cu = makeTransformerEngineTensor(ret);
auto ret_per_tensor_cu = makeTransformerEngineTensor(ret_per_tensor); auto ret_per_tensor_cu = makeTransformerEngineTensor(ret_per_tensor);
auto inv_scale_cu = makeTransformerEngineTensor(inv_scale); auto inv_scale_cu = makeTransformerEngineTensor(inv_scale);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_unscale_l2norm_cuda( nvte_multi_tensor_unscale_l2norm_cuda(
chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors,
output_cu.data(), output_per_tensor_cu.data(), ret_cu.data(), ret_per_tensor_cu.data(), output_cu.data(), output_per_tensor_cu.data(), ret_cu.data(), ret_per_tensor_cu.data(),
inv_scale_cu.data(), per_tensor, max_chunks_per_tensor, device_id, inv_scale_cu.data(), per_tensor, max_chunks_per_tensor, at::cuda::getCurrentCUDAStream());
at::cuda::getCurrentCUDAStream());
return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor); return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
} }
......
...@@ -13,10 +13,9 @@ void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag, ...@@ -13,10 +13,9 @@ void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag); auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists); makeTransformerEngineTensorList(tensor_lists);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_scale_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, nvte_multi_tensor_scale_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists,
num_tensors, scale, device_id, at::cuda::getCurrentCUDAStream()); num_tensors, scale, at::cuda::getCurrentCUDAStream());
} }
} // namespace transformer_engine::pytorch } // namespace transformer_engine::pytorch
...@@ -15,11 +15,10 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag, ...@@ -15,11 +15,10 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag); auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists); makeTransformerEngineTensorList(tensor_lists);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_sgd_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, nvte_multi_tensor_sgd_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists,
num_tensors, wd, momentum, dampening, lr, nesterov, first_run, num_tensors, wd, momentum, dampening, lr, nesterov, first_run,
wd_after_momentum, scale, device_id, at::cuda::getCurrentCUDAStream()); wd_after_momentum, scale, at::cuda::getCurrentCUDAStream());
} }
} // namespace transformer_engine::pytorch } // namespace transformer_engine::pytorch
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment