Unverified commit 1470116e authored by Kirthi Shankar Sivamani, committed by GitHub
Browse files

[C][PyTorch] Remove deprecated `device_id` arg for multi tensor API (#1994)



* Remove deprecated device arg
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Remove test
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 374849e3
......@@ -112,13 +112,6 @@ class TestFusedAdam(TestFusedOptimizer):
def test_bfloat16(self):
self.gen_single_type_test(param_type=torch.bfloat16, skip_assert=True)
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="more than 1 GPU required")
def test_multi_device(self):
devices = ("cuda:0", "cuda:1")
for current_dev, tensor_dev in product(devices, devices):
with torch.cuda.device(current_dev):
self.gen_single_type_test(param_type=torch.float, device=tensor_dev)
def test_multi_params(self):
sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
......@@ -530,13 +523,6 @@ class TestFusedSGD(TestFusedOptimizer):
def test_half(self):
self.gen_single_type_test(param_type=torch.float16)
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="more than 1 GPU required")
def test_multi_device(self):
devices = ("cuda:0", "cuda:1")
for current_dev, tensor_dev in product(devices, devices):
with torch.cuda.device(current_dev):
self.gen_single_type_test(param_type=torch.float, device=tensor_dev)
class Model(torch.nn.Module):
def __init__(self):
......
......@@ -20,7 +20,6 @@ extern "C" {
/*! \brief Computes L2 norm for a list of tensors.
*
* \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
......@@ -33,22 +32,19 @@ extern "C" {
* \param[out] ret_per_tensor L2 norm for each tensor.
* \param[in] per_tensor Whether to calculate per tensor or cumulative norm.
* \param[in] max_chunks_per_tensor Maximum number of chunks in any input tensor.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
NVTETensor output, NVTETensor output_per_tensor, NVTETensor ret,
NVTETensor ret_per_tensor, int per_tensor,
int max_chunks_per_tensor, const int device_id,
cudaStream_t stream);
int max_chunks_per_tensor, cudaStream_t stream);
/*! \brief Computes L2 norm for a list of tensors after unscaling.
*
* Unscaling is only done for computing the L2 norm. The tensors themselves are not updated.
*
* \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
......@@ -62,7 +58,6 @@ void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETen
* \param[in] inv_scale Scalar for the unscaling operation.
* \param[in] per_tensor Whether to calculate per tensor or cumulative norm.
* \param[in] max_chunks_per_tensor Maximum number of chunks in any input tensor.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
......@@ -71,12 +66,11 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
NVTETensor output_per_tensor, NVTETensor ret,
NVTETensor ret_per_tensor, NVTETensor inv_scale,
int per_tensor, int max_chunks_per_tensor,
const int device_id, cudaStream_t stream);
cudaStream_t stream);
/*! \brief Compute and apply gradient update to parameters for Adam optimizer.
*
* \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
......@@ -91,7 +85,6 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
* \param[in] mode Whether to use AdamW (L2 penalty applied to params).
* \param[in] bias_correction Whether to apply correction factor for moment estimates.
* \param[in] weight_decay L2 penalty for weight decay.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
......@@ -99,13 +92,12 @@ void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETenso
const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode,
const int bias_correction, const float weight_decay,
const int device_id, cudaStream_t stream);
cudaStream_t stream);
/*! \brief Compute and apply gradient update to parameters for Adam optimizer
* where the master parameters only store the remainder bits.
*
* \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
......@@ -120,20 +112,18 @@ void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETenso
* \param[in] mode Whether to use AdamW (L2 penalty applied to params).
* \param[in] bias_correction Whether to apply correction factor for moment estimates.
* \param[in] weight_decay L2 penalty for weight decay.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_adam_param_remainder_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode, const int bias_correction,
const float weight_decay, const int device_id, cudaStream_t stream);
const float weight_decay, cudaStream_t stream);
/*! \brief Compute and apply gradient update to parameters for Adam optimizer
* when model parameters are in Float8 precision.
*
* \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
......@@ -149,7 +139,6 @@ void nvte_multi_tensor_adam_param_remainder_cuda(
* \param[in] bias_correction Whether to apply correction factor for moment estimates.
* \param[in] weight_decay L2 penalty for weight decay.
* \param[in] fp8_dtype FP8 data type for model parameters.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
......@@ -158,13 +147,12 @@ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction,
const float weight_decay, const NVTEDType fp8_dtype,
const int device_id, cudaStream_t stream);
cudaStream_t stream);
/*! \brief Compute and apply gradient update to parameters for Adam optimizer
* with CUDA graph support and LR scheduling.
*
* \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
......@@ -180,20 +168,18 @@ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
* \param[in] bias_correction Whether to apply correction factor for moment estimates.
* \param[in] weight_decay L2 penalty for weight decay.
* \param[in] inv_scale Scalar for the unscaling operation.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_adam_capturable_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2,
const float epsilon, NVTETensor step, const int mode, const int bias_correction,
const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream);
const float weight_decay, NVTETensor inv_scale, cudaStream_t stream);
/*! \brief Compute and apply gradient update to parameters for Adam optimizer
* with CUDA graph support, LR scheduling, and FP32 master weights.
*
* \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
......@@ -209,19 +195,17 @@ void nvte_multi_tensor_adam_capturable_cuda(
* \param[in] bias_correction Whether to apply correction factor for moment estimates.
* \param[in] weight_decay L2 penalty for weight decay.
* \param[in] inv_scale Scalar for the unscaling operation.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_adam_capturable_master_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2,
const float epsilon, NVTETensor step, const int mode, const int bias_correction,
const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream);
const float weight_decay, NVTETensor inv_scale, cudaStream_t stream);
/*! \brief Compute and apply gradient update to parameters for SGD optimizer.
*
* \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
......@@ -236,19 +220,17 @@ void nvte_multi_tensor_adam_capturable_master_cuda(
* \param[in] first_run Whether momentum buffers have been initialized.
 * \param[in] wd_after_momentum Whether to apply weight decay after the momentum update.
* \param[in] scale Scalar for the scaling operation.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
float wd, float momentum, float dampening, float lr, int nesterov,
int first_run, int wd_after_momentum, float scale,
const int device_id, cudaStream_t stream);
cudaStream_t stream);
/*! \brief Check overflow and scale a list of tensors.
*
* \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
......@@ -256,17 +238,15 @@ void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor
* \param[in] num_tensor_lists Size (dim0) of tensor_lists.
* \param[in] num_tensors_per_list Size (dim1) of tensor_lists.
* \param[in] scale Scalar for the scaling operation.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
float scale, const int device_id, cudaStream_t stream);
float scale, cudaStream_t stream);
/*! \brief Check overflow and scale a list of tensors.
*
* \warning This API is **experimental** and subject to change.
* \warning Argument device_id is deprecated and will be removed in a future release.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately.
......@@ -276,13 +256,14 @@ void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETens
 * \param[in] max_fp8 Maximum representable value in underlying FP8 format.
* \param[in] force_pow_2_scales Ensure scaling factors are a power of 2.
* \param[in] epsilon Term added to the denominator for numerical stability.
* \param[in] device_id [DEPRECATED] CUDA device ID for this operation.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, float max_fp8, int force_pow_2_scales, float epsilon,
const int device_id, cudaStream_t stream);
void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, NVTETensor noop_flag,
NVTETensor **tensor_lists,
const size_t num_tensor_lists,
const size_t num_tensors_per_list,
float max_fp8, int force_pow_2_scales,
float epsilon, cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
......
......@@ -576,7 +576,7 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor *>> tensor_lists, const float lr,
const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction,
const float weight_decay, const int device_id, cudaStream_t stream) {
const float weight_decay, cudaStream_t stream) {
// Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
if (bias_correction == 1) {
......@@ -643,20 +643,20 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
g_in_type_te, g_in_type,
multi_tensor_apply<4>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag,
tensor_lists,
AdamFunctor<p_in_type, g_in_type, float, int64_t>(), device_id,
stream, beta1, beta2, bias_correction1, bias_correction2,
epsilon, lr, (adamMode_t)mode, weight_decay);));
AdamFunctor<p_in_type, g_in_type, float, int64_t>(), stream,
beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay);));
} else {
// g, p, m, v, p_master
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
p_in_type_te, p_in_type,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
g_in_type_te, g_in_type,
multi_tensor_apply<5>(
(int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,
AdamFunctorMaster<p_in_type, g_in_type, float, int64_t>(), device_id, stream,
beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode,
weight_decay);));
multi_tensor_apply<5>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag,
tensor_lists,
AdamFunctorMaster<p_in_type, g_in_type, float, int64_t>(),
stream, beta1, beta2, bias_correction1, bias_correction2,
epsilon, lr, (adamMode_t)mode, weight_decay);));
}
} else {
if (num_tensor_lists == 4) {
......@@ -666,9 +666,9 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
g_in_type_te, g_in_type,
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
AdamFunctor<p_in_type, g_in_type, float, int32_t>(), device_id,
stream, beta1, beta2, bias_correction1, bias_correction2,
epsilon, lr, (adamMode_t)mode, weight_decay);));
AdamFunctor<p_in_type, g_in_type, float, int32_t>(), stream,
beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay);));
} else {
// g, p, m, v, p_master
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
......@@ -677,9 +677,8 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
g_in_type_te, g_in_type,
multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
AdamFunctorMaster<p_in_type, g_in_type, float, int32_t>(),
device_id, stream, beta1, beta2, bias_correction1,
bias_correction2, epsilon, lr, (adamMode_t)mode,
weight_decay);));
stream, beta1, beta2, bias_correction1, bias_correction2,
epsilon, lr, (adamMode_t)mode, weight_decay);));
}
}
NVTE_CHECK_CUDA(cudaGetLastError());
......@@ -690,7 +689,7 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, Tensor noop_flag,
const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode,
const int bias_correction, const float weight_decay,
const int device_id, cudaStream_t stream) {
cudaStream_t stream) {
// Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
if (bias_correction == 1) {
......@@ -732,8 +731,8 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, Tensor noop_flag,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
g_in_type_te, g_in_type,
multi_tensor_apply<5>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,
AdamFunctorMasterParamRemainder<g_in_type, float, int64_t>(), device_id,
stream, beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
AdamFunctorMasterParamRemainder<g_in_type, float, int64_t>(), stream,
beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay););
NVTE_CHECK_CUDA(cudaGetLastError());
}
......@@ -743,7 +742,7 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag,
const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction,
const float weight_decay, const DType fp8_dtype,
const int device_id, cudaStream_t stream) {
cudaStream_t stream) {
// Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
if (bias_correction == 1) {
......@@ -813,9 +812,8 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag,
g_in_type_te, g_in_type,
multi_tensor_apply<5, true>(
(int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,
AdamFunctorMaster<FP8_T, g_in_type, float, int64_t>(), device_id, stream, beta1,
beta2, bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode,
weight_decay);));
AdamFunctorMaster<FP8_T, g_in_type, float, int64_t>(), stream, beta1, beta2,
bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, weight_decay);));
} else {
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
fp8_dtype, FP8_T,
......@@ -823,9 +821,8 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag,
g_in_type_te, g_in_type,
multi_tensor_apply<5, true>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
AdamFunctorMaster<FP8_T, g_in_type, float, int32_t>(),
device_id, stream, beta1, beta2, bias_correction1,
bias_correction2, epsilon, lr, (adamMode_t)mode,
weight_decay);));
stream, beta1, beta2, bias_correction1, bias_correction2,
epsilon, lr, (adamMode_t)mode, weight_decay);));
}
NVTE_CHECK_CUDA(cudaGetLastError());
}
......@@ -835,7 +832,7 @@ void multi_tensor_adam_capturable_cuda(int chunk_size, Tensor noop_flag,
const float beta1, const float beta2, const float epsilon,
Tensor step, const int mode, const int bias_correction,
const float weight_decay, Tensor inv_scale,
const int device_id, cudaStream_t stream) {
cudaStream_t stream) {
// Check tensor list sizes
// 4 tensor lists: g, p, m, v
const size_t num_tensor_lists = tensor_lists.size();
......@@ -867,7 +864,7 @@ void multi_tensor_adam_capturable_cuda(int chunk_size, Tensor noop_flag,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
tensor_lists[0][0]->dtype(), dtype,
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
AdamCapturableFunctor<dtype, float>(), device_id, stream, beta1, beta2,
AdamCapturableFunctor<dtype, float>(), stream, beta1, beta2,
reinterpret_cast<int *>(step.data.dptr), bias_correction, epsilon,
reinterpret_cast<float *>(lr.data.dptr), (adamMode_t)mode, weight_decay,
reinterpret_cast<float *>(inv_scale.data.dptr));)
......@@ -880,8 +877,7 @@ void multi_tensor_adam_capturable_master_cuda(int chunk_size, Tensor noop_flag,
Tensor lr, const float beta1, const float beta2,
const float epsilon, Tensor step, const int mode,
const int bias_correction, const float weight_decay,
Tensor inv_scale, const int device_id,
cudaStream_t stream) {
Tensor inv_scale, cudaStream_t stream) {
// Check tensor list sizes
// 4 tensor lists: g, p, m, v, p_master
const size_t num_tensor_lists = tensor_lists.size();
......@@ -916,10 +912,10 @@ void multi_tensor_adam_capturable_master_cuda(int chunk_size, Tensor noop_flag,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
tensor_lists[0][0]->dtype(), dtype,
multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
AdamCapturableMasterFunctor<dtype, float>(), device_id, stream, beta1,
beta2, reinterpret_cast<int *>(step.data.dptr), bias_correction,
epsilon, reinterpret_cast<float *>(lr.data.dptr), (adamMode_t)mode,
weight_decay, reinterpret_cast<float *>(inv_scale.data.dptr));)
AdamCapturableMasterFunctor<dtype, float>(), stream, beta1, beta2,
reinterpret_cast<int *>(step.data.dptr), bias_correction, epsilon,
reinterpret_cast<float *>(lr.data.dptr), (adamMode_t)mode, weight_decay,
reinterpret_cast<float *>(inv_scale.data.dptr));)
NVTE_CHECK_CUDA(cudaGetLastError());
}
......@@ -932,28 +928,28 @@ void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETenso
const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode,
const int bias_correction, const float weight_decay,
const int device_id, cudaStream_t stream) {
cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_adam_cuda);
using namespace transformer_engine;
multi_tensor_adam::multi_tensor_adam_cuda(
chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2,
epsilon, step, mode, bias_correction, weight_decay, device_id, stream);
epsilon, step, mode, bias_correction, weight_decay, stream);
}
void nvte_multi_tensor_adam_param_remainder_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode, const int bias_correction,
const float weight_decay, const int device_id, cudaStream_t stream) {
const float weight_decay, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_adam_param_remainder_cuda);
using namespace transformer_engine;
multi_tensor_adam::multi_tensor_adam_param_remainder_cuda(
chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2,
epsilon, step, mode, bias_correction, weight_decay, device_id, stream);
epsilon, step, mode, bias_correction, weight_decay, stream);
}
void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
......@@ -962,22 +958,21 @@ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction,
const float weight_decay, const NVTEDType fp8_dtype,
const int device_id, cudaStream_t stream) {
cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_adam_fp8_cuda);
using namespace transformer_engine;
multi_tensor_adam::multi_tensor_adam_fp8_cuda(
chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2,
epsilon, step, mode, bias_correction, weight_decay, static_cast<DType>(fp8_dtype), device_id,
stream);
epsilon, step, mode, bias_correction, weight_decay, static_cast<DType>(fp8_dtype), stream);
}
void nvte_multi_tensor_adam_capturable_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2,
const float epsilon, NVTETensor step, const int mode, const int bias_correction,
const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream) {
const float weight_decay, NVTETensor inv_scale, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_adam_capturable_cuda);
using namespace transformer_engine;
......@@ -985,14 +980,14 @@ void nvte_multi_tensor_adam_capturable_cuda(
chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
*convertNVTETensorCheck(lr), beta1, beta2, epsilon, *convertNVTETensorCheck(step), mode,
bias_correction, weight_decay, *convertNVTETensorCheck(inv_scale), device_id, stream);
bias_correction, weight_decay, *convertNVTETensorCheck(inv_scale), stream);
}
void nvte_multi_tensor_adam_capturable_master_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2,
const float epsilon, NVTETensor step, const int mode, const int bias_correction,
const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream) {
const float weight_decay, NVTETensor inv_scale, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_adam_capturable_master_cuda);
using namespace transformer_engine;
......@@ -1000,5 +995,5 @@ void nvte_multi_tensor_adam_capturable_master_cuda(
chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
*convertNVTETensorCheck(lr), beta1, beta2, epsilon, *convertNVTETensorCheck(step), mode,
bias_correction, weight_decay, *convertNVTETensorCheck(inv_scale), device_id, stream);
bias_correction, weight_decay, *convertNVTETensorCheck(inv_scale), stream);
}
......@@ -58,26 +58,27 @@ struct ComputeScaleAndScaleInvFunctor {
void multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor *>> tensor_lists,
float max_fp8, bool force_pow_2_scales,
float epsilon, const int device_id,
cudaStream_t stream) {
float epsilon, cudaStream_t stream) {
multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
ComputeScaleAndScaleInvFunctor(), device_id, stream, max_fp8,
force_pow_2_scales, epsilon);
ComputeScaleAndScaleInvFunctor(), stream, max_fp8, force_pow_2_scales,
epsilon);
NVTE_CHECK_CUDA(cudaGetLastError());
}
} // namespace multi_tensor_compute_scale
} // namespace transformer_engine
void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, float max_fp8, int force_pow_2_scales, float epsilon,
const int device_id, cudaStream_t stream) {
void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, NVTETensor noop_flag,
NVTETensor **tensor_lists,
const size_t num_tensor_lists,
const size_t num_tensors_per_list,
float max_fp8, int force_pow_2_scales,
float epsilon, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_compute_scale_and_scale_inv_cuda);
using namespace transformer_engine;
multi_tensor_compute_scale::multi_tensor_compute_scale_and_scale_inv_cuda(
chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), max_fp8,
force_pow_2_scales, epsilon, device_id, stream);
force_pow_2_scales, epsilon, stream);
}
......@@ -393,13 +393,12 @@ __global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret,
void multi_tensor_l2norm_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor *>> tensor_lists, Tensor output,
Tensor output_per_tensor, Tensor ret, Tensor ret_per_tensor,
bool per_tensor, int max_chunks_per_tensor, const int device_id,
cudaStream_t stream) {
bool per_tensor, int max_chunks_per_tensor, cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
tensor_lists[0][0]->dtype(), dtype,
multi_tensor_apply<1>(
BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, L2NormFunctor<dtype>(), device_id,
stream, reinterpret_cast<float *>(output.data.dptr),
BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, L2NormFunctor<dtype>(), stream,
reinterpret_cast<float *>(output.data.dptr),
per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr, per_tensor,
max_chunks_per_tensor);)
......@@ -408,7 +407,6 @@ void multi_tensor_l2norm_cuda(int chunk_size, Tensor noop_flag,
// This involves one more small kernel launch, but will be negligible end to end.
// I could get rid of these by hacking the functor + multi tensor harness with persistence
// logic, but keeping it simple for now
const OptionalCUDAGuard device_guard(device_id);
cleanup<<<per_tensor ? tensor_lists[0].size() : 1, 512, 0, stream>>>(
reinterpret_cast<float *>(output.data.dptr),
per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr,
......@@ -421,13 +419,12 @@ void multi_tensor_unscale_l2norm_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor *>> tensor_lists,
Tensor output, Tensor output_per_tensor, Tensor ret,
Tensor ret_per_tensor, Tensor inv_scale, bool per_tensor,
int max_chunks_per_tensor, const int device_id,
cudaStream_t stream) {
int max_chunks_per_tensor, cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
tensor_lists[0][0]->dtype(), dtype,
multi_tensor_apply<1>(
BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, UnscaleL2NormFunctor<dtype>(), device_id,
stream, reinterpret_cast<float *>(inv_scale.data.dptr),
BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, UnscaleL2NormFunctor<dtype>(), stream,
reinterpret_cast<float *>(inv_scale.data.dptr),
reinterpret_cast<float *>(output.data.dptr),
per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr, per_tensor,
max_chunks_per_tensor);)
......@@ -437,7 +434,6 @@ void multi_tensor_unscale_l2norm_cuda(int chunk_size, Tensor noop_flag,
// This involves one more small kernel launch, but will be negligible end to end.
// I could get rid of these by hacking the functor + multi tensor harness with persistence
// logic, but keeping it simple for now
const OptionalCUDAGuard device_guard(device_id);
cleanup<<<per_tensor ? tensor_lists[0].size() : 1, 512, 0, stream>>>(
reinterpret_cast<float *>(output.data.dptr),
per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr,
......@@ -453,8 +449,7 @@ void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETen
const size_t num_tensor_lists, const size_t num_tensors_per_list,
NVTETensor output, NVTETensor output_per_tensor, NVTETensor ret,
NVTETensor ret_per_tensor, int per_tensor,
int max_chunks_per_tensor, const int device_id,
cudaStream_t stream) {
int max_chunks_per_tensor, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_l2norm_cuda);
using namespace transformer_engine;
......@@ -463,7 +458,7 @@ void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETen
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
*convertNVTETensorCheck(output), *convertNVTETensorCheck(output_per_tensor),
*convertNVTETensorCheck(ret), *convertNVTETensorCheck(ret_per_tensor), per_tensor,
max_chunks_per_tensor, device_id, stream);
max_chunks_per_tensor, stream);
}
void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
......@@ -472,7 +467,7 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
NVTETensor output_per_tensor, NVTETensor ret,
NVTETensor ret_per_tensor, NVTETensor inv_scale,
int per_tensor, int max_chunks_per_tensor,
const int device_id, cudaStream_t stream) {
cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_unscale_l2norm_cuda);
using namespace transformer_engine;
......@@ -481,5 +476,5 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
*convertNVTETensorCheck(output), *convertNVTETensorCheck(output_per_tensor),
*convertNVTETensorCheck(ret), *convertNVTETensorCheck(ret_per_tensor),
*convertNVTETensorCheck(inv_scale), per_tensor, max_chunks_per_tensor, device_id, stream);
*convertNVTETensorCheck(inv_scale), per_tensor, max_chunks_per_tensor, stream);
}
......@@ -14,53 +14,6 @@
// This header is the one-stop shop for all your multi-tensor apply needs.
// Change device if needed.
// RAII guard that switches the current CUDA device to `new_device` for the
// guard's lifetime and restores the previously active device on destruction.
// A negative `new_device` makes the guard a no-op (hence "Optional").
class OptionalCUDAGuard {
 public:
  // Switches to `new_device` only if it is non-negative and differs from the
  // currently active device; otherwise the guard does nothing.
  explicit OptionalCUDAGuard(int new_device) {
    if (new_device < 0) return;  // negative id acts as a "no change" sentinel
    int current_device;
    NVTE_CHECK_CUDA(cudaGetDevice(&current_device));
    if (new_device != current_device) {
      NVTE_CHECK_CUDA(cudaSetDevice(new_device));
      device_changed_ = true;
      prev_device_ = current_device;
    }
  }

  // Non-copyable: two copies would both try to restore the previous device.
  OptionalCUDAGuard(const OptionalCUDAGuard &) = delete;
  OptionalCUDAGuard &operator=(const OptionalCUDAGuard &) = delete;

  // Move transfers responsibility for restoring the previous device; the
  // moved-from guard is disarmed so it does not restore on destruction.
  OptionalCUDAGuard(OptionalCUDAGuard &&other) noexcept
      : prev_device_(other.prev_device_), device_changed_(other.device_changed_) {
    other.device_changed_ = false;
  }

  OptionalCUDAGuard &operator=(OptionalCUDAGuard &&other) noexcept {
    if (this != &other) {
      if (device_changed_) {
        // Restore our own saved device before adopting the other guard's state.
        cudaSetDevice(prev_device_);
      }
      prev_device_ = other.prev_device_;
      device_changed_ = other.device_changed_;
      other.device_changed_ = false;
    }
    return *this;
  }

  // Restores the device active at construction time, if it was changed.
  // The cudaSetDevice error code is intentionally ignored: destructors must
  // not throw, and there is no caller to report the failure to.
  ~OptionalCUDAGuard() {
    if (device_changed_) {
      cudaSetDevice(prev_device_);
    }
  }

 private:
  int prev_device_;              // device to restore; valid only when device_changed_ is true
  bool device_changed_ = false;  // whether the constructor actually switched devices
};
// TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson)
// Per-depth capacity limits for a single multi_tensor_apply kernel launch:
// maximum number of tensors and CUDA blocks whose metadata fits in the kernel
// argument space. Presumably indexed by (depth - 1) for depths 1..6 — confirm
// against the TensorListMetadata definition, which is not visible here.
constexpr int depth_to_max_tensors[6] = {110, 64, 48, 36, 30, 24};
constexpr int depth_to_max_blocks[6] = {320, 320, 320, 320, 320, 320};
......@@ -94,7 +47,7 @@ template <int depth, bool USE_FP8 = false, typename T, typename... ArgTypes>
void multi_tensor_apply(int64_t block_size, int64_t chunk_size,
const transformer_engine::Tensor &noop_flag,
std::vector<std::vector<transformer_engine::Tensor *>> tensor_lists,
T callable, const int device_id, cudaStream_t stream, ArgTypes... args) {
T callable, cudaStream_t stream, ArgTypes... args) {
const size_t num_tensor_lists = tensor_lists.size();
const size_t num_tensors_per_list = tensor_lists[0].size();
......@@ -108,8 +61,6 @@ void multi_tensor_apply(int64_t block_size, int64_t chunk_size,
TensorListMetadata<depth, USE_FP8> tl;
const OptionalCUDAGuard device_guard(device_id);
tl.start_tensor_this_launch = 0;
int loc_block_info = 0;
int loc_tensor_info = 0;
......
......@@ -104,13 +104,13 @@ struct ScaleFunctor {
void multi_tensor_scale_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor *>> tensor_lists, float scale,
const int device_id, cudaStream_t stream) {
cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
tensor_lists[0][0]->dtype(), p_in_type,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
tensor_lists[1][0]->dtype(), g_in_type,
multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
ScaleFunctor<p_in_type, g_in_type>(), device_id, stream, scale);))
ScaleFunctor<p_in_type, g_in_type>(), stream, scale);))
NVTE_CHECK_CUDA(cudaGetLastError());
}
......@@ -119,12 +119,11 @@ void multi_tensor_scale_cuda(int chunk_size, Tensor noop_flag,
void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
float scale, const int device_id, cudaStream_t stream) {
float scale, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_scale_cuda);
using namespace transformer_engine;
multi_tensor_scale::multi_tensor_scale_cuda(
chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), scale, device_id,
stream);
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), scale, stream);
}
......@@ -127,8 +127,7 @@ struct SGDFunctor {
void multi_tensor_sgd_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor*>> tensor_lists, float wd, float momentum,
float dampening, float lr, bool nesterov, bool first_run,
bool wd_after_momentum, float scale, const int device_id,
cudaStream_t stream) {
bool wd_after_momentum, float scale, cudaStream_t stream) {
const size_t num_tensor_lists = tensor_lists.size();
const size_t num_tensors_per_list = tensor_lists[0].size();
......@@ -154,29 +153,29 @@ void multi_tensor_sgd_cuda(int chunk_size, Tensor noop_flag,
// Case 1. fp16, fp16, fp16, No
if (grad_type == DType::kFloat16 && weight_type == DType::kFloat16 && num_tensor_lists == 3) {
multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
SGDFunctor<3, fp16, fp16>(), device_id, stream, wd, momentum, dampening,
lr, nesterov, first_run, wd_after_momentum, scale);
SGDFunctor<3, fp16, fp16>(), stream, wd, momentum, dampening, lr,
nesterov, first_run, wd_after_momentum, scale);
}
// Case 2. fp32, fp32, fp32, No
else if (grad_type == DType::kFloat32 && // NOLINT(*)
weight_type == DType::kFloat32 && num_tensor_lists == 3) {
multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
SGDFunctor<3, float, float>(), device_id, stream, wd, momentum, dampening,
lr, nesterov, first_run, wd_after_momentum, scale);
SGDFunctor<3, float, float>(), stream, wd, momentum, dampening, lr,
nesterov, first_run, wd_after_momentum, scale);
}
// Case 3. fp16, fp32, fp32, Yes
else if (grad_type == DType::kFloat16 && // NOLINT(*)
weight_type == DType::kFloat32 && num_tensor_lists == 4) {
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
SGDFunctor<4, fp16, float>(), device_id, stream, wd, momentum, dampening,
lr, nesterov, first_run, wd_after_momentum, scale);
SGDFunctor<4, fp16, float>(), stream, wd, momentum, dampening, lr,
nesterov, first_run, wd_after_momentum, scale);
}
// Case 4. fp32, fp32, fp32, Yes
else if (grad_type == DType::kFloat32 && // NOLINT(*)
weight_type == DType::kFloat32 && num_tensor_lists == 4) {
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
SGDFunctor<4, float, float>(), device_id, stream, wd, momentum, dampening,
lr, nesterov, first_run, wd_after_momentum, scale);
SGDFunctor<4, float, float>(), stream, wd, momentum, dampening, lr,
nesterov, first_run, wd_after_momentum, scale);
} else {
NVTE_ERROR("Unsupported combination of weight and gradient types.");
}
......@@ -191,12 +190,12 @@ void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor
const size_t num_tensor_lists, const size_t num_tensors_per_list,
float wd, float momentum, float dampening, float lr, int nesterov,
int first_run, int wd_after_momentum, float scale,
const int device_id, cudaStream_t stream) {
cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_sgd_cuda);
using namespace transformer_engine;
multi_tensor_sgd::multi_tensor_sgd_cuda(
chunk_size, *convertNVTETensorCheck(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), wd, momentum,
dampening, lr, nesterov, first_run, wd_after_momentum, scale, device_id, stream);
dampening, lr, nesterov, first_run, wd_after_momentum, scale, stream);
}
......@@ -16,11 +16,10 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_adam_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists,
num_tensors, lr, beta1, beta2, epsilon, step, mode, bias_correction,
weight_decay, device_id, at::cuda::getCurrentCUDAStream());
weight_decay, at::cuda::getCurrentCUDAStream());
}
void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag,
......@@ -31,12 +30,10 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_adam_param_remainder_cuda(
chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, lr, beta1,
beta2, epsilon, step, mode, bias_correction, weight_decay, device_id,
at::cuda::getCurrentCUDAStream());
beta2, epsilon, step, mode, bias_correction, weight_decay, at::cuda::getCurrentCUDAStream());
}
void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
......@@ -47,12 +44,11 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_adam_fp8_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(),
num_lists, num_tensors, lr, beta1, beta2, epsilon, step, mode,
bias_correction, weight_decay, static_cast<NVTEDType>(fp8_dtype),
device_id, at::cuda::getCurrentCUDAStream());
at::cuda::getCurrentCUDAStream());
}
void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag,
......@@ -67,12 +63,11 @@ void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag,
auto lr_cu = makeTransformerEngineTensor(lr);
auto step_cu = makeTransformerEngineTensor(step);
auto inv_scale_cu = makeTransformerEngineTensor(inv_scale);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_adam_capturable_cuda(
chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors,
lr_cu.data(), beta1, beta2, epsilon, step_cu.data(), mode, bias_correction, weight_decay,
inv_scale_cu.data(), device_id, at::cuda::getCurrentCUDAStream());
inv_scale_cu.data(), at::cuda::getCurrentCUDAStream());
}
void multi_tensor_adam_capturable_master_cuda(int chunk_size, at::Tensor noop_flag,
......@@ -87,12 +82,11 @@ void multi_tensor_adam_capturable_master_cuda(int chunk_size, at::Tensor noop_fl
auto lr_cu = makeTransformerEngineTensor(lr);
auto step_cu = makeTransformerEngineTensor(step);
auto inv_scale_cu = makeTransformerEngineTensor(inv_scale);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_adam_capturable_master_cuda(
chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors,
lr_cu.data(), beta1, beta2, epsilon, step_cu.data(), mode, bias_correction, weight_decay,
inv_scale_cu.data(), device_id, at::cuda::getCurrentCUDAStream());
inv_scale_cu.data(), at::cuda::getCurrentCUDAStream());
}
} // namespace transformer_engine::pytorch
......@@ -14,11 +14,10 @@ void multi_tensor_compute_scale_and_scale_inv_cuda(
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_compute_scale_and_scale_inv_cuda(
chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, max_fp8,
force_pow_2_scales, epsilon, device_id, at::cuda::getCurrentCUDAStream());
force_pow_2_scales, epsilon, at::cuda::getCurrentCUDAStream());
}
} // namespace transformer_engine::pytorch
......@@ -43,12 +43,11 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
auto output_per_tensor_cu = makeTransformerEngineTensor(output_per_tensor);
auto ret_cu = makeTransformerEngineTensor(ret);
auto ret_per_tensor_cu = makeTransformerEngineTensor(ret_per_tensor);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_l2norm_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists,
num_tensors, output_cu.data(), output_per_tensor_cu.data(),
ret_cu.data(), ret_per_tensor_cu.data(), per_tensor,
max_chunks_per_tensor, device_id, at::cuda::getCurrentCUDAStream());
max_chunks_per_tensor, at::cuda::getCurrentCUDAStream());
return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
}
......@@ -91,13 +90,11 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(
auto ret_cu = makeTransformerEngineTensor(ret);
auto ret_per_tensor_cu = makeTransformerEngineTensor(ret_per_tensor);
auto inv_scale_cu = makeTransformerEngineTensor(inv_scale);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_unscale_l2norm_cuda(
chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors,
output_cu.data(), output_per_tensor_cu.data(), ret_cu.data(), ret_per_tensor_cu.data(),
inv_scale_cu.data(), per_tensor, max_chunks_per_tensor, device_id,
at::cuda::getCurrentCUDAStream());
inv_scale_cu.data(), per_tensor, max_chunks_per_tensor, at::cuda::getCurrentCUDAStream());
return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
}
......
......@@ -13,10 +13,9 @@ void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_scale_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists,
num_tensors, scale, device_id, at::cuda::getCurrentCUDAStream());
num_tensors, scale, at::cuda::getCurrentCUDAStream());
}
} // namespace transformer_engine::pytorch
......@@ -15,11 +15,10 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_sgd_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists,
num_tensors, wd, momentum, dampening, lr, nesterov, first_run,
wd_after_momentum, scale, device_id, at::cuda::getCurrentCUDAStream());
wd_after_momentum, scale, at::cuda::getCurrentCUDAStream());
}
} // namespace transformer_engine::pytorch
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment