Unverified Commit c972f5a7 authored by Kirthi Shankar Sivamani, committed by GitHub

[C][PyTorch] Move multi-tensor kernels from PyTorch extensions to core (#1744)



* Move multi-tensor kernels from PyTorch extensions to core
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add int16 type to core (for storing fp32 param remainders)
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix core build
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Apply the same fix to scale
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix perf, memory, vars
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Re-add device guard for multi-device
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix junk output dtype for the non-per-tensor case
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes for tests and upgrade mcore version
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix core tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent e17fab14
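One bullet above adds an int16 type to core for storing fp32 param remainders. A minimal sketch of that technique follows; it is an assumption inferred from the bullet, not code from this diff. The idea is that the low 16 bits of an fp32 bit pattern, stored as int16 alongside a bf16-style high half, let the optimizer reconstruct the exact fp32 master parameter:

#include <cstdint>
#include <cstring>
#include <cstdio>

int main() {
  float master = 0.1234567f;
  uint32_t bits;
  std::memcpy(&bits, &master, sizeof(bits));

  // High 16 bits are the bf16 payload; low 16 bits are the "remainder",
  // which an int16 tensor can hold bit-for-bit.
  uint16_t high = static_cast<uint16_t>(bits >> 16);
  int16_t remainder = static_cast<int16_t>(bits & 0xFFFFu);

  // Reassembling the two halves recovers the original fp32 bit pattern.
  uint32_t rebuilt = (static_cast<uint32_t>(high) << 16) |
                     (static_cast<uint32_t>(remainder) & 0xFFFFu);
  float restored;
  std::memcpy(&restored, &rebuilt, sizeof(restored));
  std::printf("master=%.9g restored=%.9g\n", master, restored);
  return 0;
}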
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, float scale) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_scale_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists,
num_tensors, scale, device_id, at::cuda::getCurrentCUDAStream());
}
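As a hypothetical call site (the function name, chunk size, and loss-scale handling below are assumptions, not part of this diff), the wrapper is typically used to unscale a list of gradients in place; the kernel becomes a no-op when noop_flag is nonzero, e.g. after an overflow check during mixed-precision training:

#include <torch/torch.h>
#include <vector>

void unscale_grads(std::vector<at::Tensor> grads, float inv_loss_scale) {
  auto noop_flag = torch::zeros({1}, torch::dtype(torch::kInt32).device(torch::kCUDA));
  // Inner list 0 holds the inputs and inner list 1 the outputs; aliasing the
  // two makes the scale operate in place.
  std::vector<std::vector<at::Tensor>> tensor_lists = {grads, grads};
  multi_tensor_scale_cuda(/*chunk_size=*/65536, noop_flag, tensor_lists, inv_loss_scale);
}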
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, float wd,
float momentum, float dampening, float lr, bool nesterov, bool first_run,
bool wd_after_momentum, float scale) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_sgd_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists,
num_tensors, wd, momentum, dampening, lr, nesterov, first_run,
wd_after_momentum, scale, device_id, at::cuda::getCurrentCUDAStream());
}
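For reference, here is a scalar sketch of the per-element update the fused kernel applies across every chunk of every tensor in tensor_lists. The semantics are assumed to follow the Apex multi_tensor_sgd kernel that this wrapper descends from, so treat it as an illustration rather than the core implementation:

float sgd_reference_step(float param, float grad, float &momentum_buf, float wd,
                         float momentum, float dampening, float lr, bool nesterov,
                         bool first_run, bool wd_after_momentum, float scale) {
  float g = grad * scale;                    // undo loss scaling
  if (!wd_after_momentum) g += wd * param;   // L2 decay folded into the gradient
  if (momentum != 0.0f) {
    momentum_buf = first_run ? g : momentum_buf * momentum + (1.0f - dampening) * g;
    g = nesterov ? g + momentum * momentum_buf : momentum_buf;
  }
  if (wd_after_momentum) g += wd * param;    // decay applied after the momentum update
  return param - lr * g;
}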