Unverified Commit c972f5a7 authored by Kirthi Shankar Sivamani, committed by GitHub

[C][PyTorch] Move multi-tensor kernels from PyTorch extensions to core (#1744)



* Move multi-tensor kernels from PyTorch extensions to core
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add int16 type to core (for storing fp32 param remainders)
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix core build
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Apply the same fix to scale
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix perf, memory, vars
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Re-add device guard for multi-device
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix junk output dtype for the non-per-tensor case
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix tests and upgrade mcore version
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix core tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent e17fab14
......@@ -17,7 +17,7 @@ fi
# Download Megatron-LM if needed
if [ ! -d "${MCORE_PATH}" ]; then
pushd $(dirname ${MCORE_PATH})
git clone -b core_r0.9.0 https://github.com/NVIDIA/Megatron-LM.git Megatron-LM
git clone -b core_r0.12.0 https://github.com/NVIDIA/Megatron-LM.git Megatron-LM
popd
fi
......
......@@ -46,6 +46,7 @@ struct BytesToType<8> {
};
using byte = uint8_t;
using int16 = int16_t;
using int32 = int32_t;
using int64 = int64_t;
using fp32 = float;
......@@ -58,6 +59,7 @@ using fp8e8m0 = uint8_t;
template <typename T>
struct TypeInfo{
using types = std::tuple<byte,
int16,
int32,
int64,
fp32,
......
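Why int16: with BF16 params plus a per-element int16 "remainder", the optimizer can reconstruct exact FP32 master weights without a separate FP32 copy. A minimal host-side sketch of that reconstruction (the helper name is ours, not part of this diff):

#include <cstdint>
#include <cstring>

// Sketch: a bf16 parameter carries the upper 16 bits of the fp32 master
// value; the int16 remainder carries the lower 16 bits. Fusing them
// recovers the fp32 master weight bit-exactly.
float reconstruct_master(uint16_t bf16_bits, int16_t remainder) {
  uint32_t bits = (static_cast<uint32_t>(bf16_bits) << 16) |
                  static_cast<uint16_t>(remainder);
  float master;
  std::memcpy(&master, &bits, sizeof(master));
  return master;
}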
......@@ -159,7 +159,7 @@ def test_multi_tensor_l2norm(input_size_pair, applier, repeat, in_type, per_tens
normab = torch.cat((a.norm().view(1), b.norm().view(1)))
norm_per_tensor = norm_per_tensor.view(-1, 2)
else:
norm, _ = applier(tex.multi_tensor_l2norm, overflow_buf, [in_list], True)
norm, _ = applier(tex.multi_tensor_l2norm, overflow_buf, [in_list], False)
reference = torch.full(
[(sizea + sizeb) * repeat], val, dtype=torch.float32, device=device
......
......@@ -53,6 +53,11 @@ list(APPEND transformer_engine_SOURCES
cudnn_utils.cpp
transformer_engine.cpp
common.cu
multi_tensor/adam.cu
multi_tensor/compute_scale.cu
multi_tensor/l2norm.cu
multi_tensor/scale.cu
multi_tensor/sgd.cu
transpose/cast_transpose.cu
transpose/transpose.cu
transpose/cast_transpose_fusion.cu
......@@ -163,6 +168,11 @@ target_include_directories(transformer_engine PRIVATE
set_source_files_properties(fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu
fused_softmax/scaled_aligned_causal_masked_softmax.cu
multi_tensor/adam.cu
multi_tensor/compute_scale.cu
multi_tensor/l2norm.cu
multi_tensor/scale.cu
multi_tensor/sgd.cu
PROPERTIES
COMPILE_OPTIONS "--use_fast_math")
option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF)
......
......@@ -193,4 +193,16 @@ bool is_supported_by_CC_100() {
return deviceComputeCapability >= 100;
}
std::vector<std::vector<Tensor *>> convert_tensor_array(NVTETensor **nvte_tensors,
size_t outer_size, size_t inner_size) {
std::vector<std::vector<Tensor *>> ret;
for (size_t i = 0; i < outer_size; ++i) {
ret.emplace_back();
for (size_t j = 0; j < inner_size; ++j) {
ret.back().push_back(reinterpret_cast<Tensor *>(nvte_tensors[i][j]));
}
}
return ret;
}
} // namespace transformer_engine
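For reference, a hedged sketch of the call-site layout this internal helper expects; g0, g1, p0, p1 stand in for valid NVTETensor handles created elsewhere:

// Two lists (e.g. grads and params) of two tensors each, flattened into the
// C-style NVTETensor** that convert_tensor_array unpacks.
NVTETensor grads[2] = {g0, g1};
NVTETensor params[2] = {p0, p1};
NVTETensor *lists[2] = {grads, params};
auto vecs = transformer_engine::convert_tensor_array(lists, /*outer_size=*/2,
                                                     /*inner_size=*/2);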
......@@ -238,6 +238,7 @@ constexpr T DIVUP(const T &x, const T &y) {
}
using byte = uint8_t;
using int16 = int16_t;
using int32 = int32_t;
using int64 = int64_t;
using fp32 = float;
......@@ -260,6 +261,7 @@ constexpr inline const char *type_name() noexcept;
return #T; \
}
TRANSFORMER_ENGINE_TYPE_NAME(uint8_t)
TRANSFORMER_ENGINE_TYPE_NAME(int16_t)
TRANSFORMER_ENGINE_TYPE_NAME(int32_t)
TRANSFORMER_ENGINE_TYPE_NAME(int64_t)
TRANSFORMER_ENGINE_TYPE_NAME(float)
......@@ -306,7 +308,7 @@ struct TypeExtrema {
template <typename T>
struct TypeInfo {
using types = std::tuple<byte, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2>;
using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2>;
template <typename U, DType current>
struct Helper {
......@@ -343,6 +345,10 @@ struct TypeInfo {
using type = unsigned char; \
{ __VA_ARGS__ } \
} break; \
case DType::kInt16: { \
using type = int16_t; \
{ __VA_ARGS__ } \
} break; \
case DType::kInt32: { \
using type = int32_t; \
{ __VA_ARGS__ } \
......@@ -576,6 +582,9 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
bool is_supported_by_CC_100();
std::vector<std::vector<Tensor *>> convert_tensor_array(NVTETensor **nvte_tensors,
size_t outer_size, size_t inner_size);
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_COMMON_H_
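With kInt16 wired into the switch macros, a runtime DType dispatches to a compile-time type as before. A sketch using the NON_FP8ONLY variant that appears later in this diff (my_kernel and tensor are placeholders, not part of the change):

TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
    tensor.dtype(), T,
    // dispatch the raw data pointer to a templated function
    my_kernel<T>(reinterpret_cast<T *>(tensor.data.dptr), tensor.numel()););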
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file multi_tensor.h
* \brief Functions handling multi tensor kernels.
*/
#ifndef TRANSFORMER_ENGINE_MULTI_TENSOR_H_
#define TRANSFORMER_ENGINE_MULTI_TENSOR_H_
#include "transformer_engine.h"
#ifdef __cplusplus
extern "C" {
#endif
void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
NVTETensor output, NVTETensor output_per_tensor, NVTETensor ret,
NVTETensor ret_per_tensor, int per_tensor,
int max_chunks_per_tensor, const int device_id,
cudaStream_t stream);
void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, NVTETensor output,
NVTETensor output_per_tensor, NVTETensor ret,
NVTETensor ret_per_tensor, NVTETensor inv_scale,
int per_tensor, int max_chunks_per_tensor,
const int device_id, cudaStream_t stream);
void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode,
const int bias_correction, const float weight_decay,
const int device_id, cudaStream_t stream);
void nvte_multi_tensor_adam_param_remainder_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode, const int bias_correction,
const float weight_decay, const int device_id, cudaStream_t stream);
void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, const float lr,
const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction,
const float weight_decay, const NVTEDType fp8_dtype,
const int device_id, cudaStream_t stream);
void nvte_multi_tensor_adam_capturable_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2,
const float epsilon, NVTETensor step, const int mode, const int bias_correction,
const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream);
void nvte_multi_tensor_adam_capturable_master_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2,
const float epsilon, NVTETensor step, const int mode, const int bias_correction,
const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream);
void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
float wd, float momentum, float dampening, float lr, int nesterov,
int first_run, int wd_after_momentum, float scale,
const int device_id, cudaStream_t stream);
void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
float scale, const int device_id, cudaStream_t stream);
void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, float max_fp8, int force_pow_2_scales, float epsilon,
const int device_id, cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
#endif // TRANSFORMER_ENGINE_MULTI_TENSOR_H_
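All of these entry points share one calling convention: tensor_lists is a [num_tensor_lists][num_tensors_per_list] array of NVTETensor handles, noop_flag makes the kernels early-exit when set, and a negative device_id leaves the current device unchanged. A hedged call sketch for the simplest function; g0, g1, o0, o1, noop, and stream are assumed to be valid handles:

NVTETensor ins[2] = {g0, g1};   // inputs
NVTETensor outs[2] = {o0, o1};  // outputs
NVTETensor *lists[2] = {ins, outs};
nvte_multi_tensor_scale_cuda(/*chunk_size=*/65536, noop, lists,
                             /*num_tensor_lists=*/2,
                             /*num_tensors_per_list=*/2,
                             /*scale=*/0.5f, /*device_id=*/0, stream);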
......@@ -23,14 +23,15 @@ extern "C" {
*/
enum NVTEDType {
kNVTEByte = 0, /*!< Byte */
kNVTEInt32 = 1, /*!< 32-bit integer */
kNVTEInt64 = 2, /*!< 64-bit integer */
kNVTEFloat32 = 3, /*!< 32-bit float */
kNVTEFloat16 = 4, /*!< 16-bit float (E5M10) */
kNVTEBFloat16 = 5, /*!< 16-bit bfloat (E8M7) */
kNVTEFloat8E4M3 = 6, /*!< 8-bit float (E4M3) */
kNVTEFloat8E5M2 = 7, /*!< 8-bit float (E5M2) */
kNVTEFloat8E8M0 = 8, /*!< 8-bit float (E8M0) */
kNVTEInt16 = 1, /*!< 16-bit integer */
kNVTEInt32 = 2, /*!< 32-bit integer */
kNVTEInt64 = 3, /*!< 64-bit integer */
kNVTEFloat32 = 4, /*!< 32-bit float */
kNVTEFloat16 = 5, /*!< 16-bit float (E5M10) */
kNVTEBFloat16 = 6, /*!< 16-bit bfloat (E8M7) */
kNVTEFloat8E4M3 = 7, /*!< 8-bit float (E4M3) */
kNVTEFloat8E5M2 = 8, /*!< 8-bit float (E5M2) */
kNVTEFloat8E8M0 = 9, /*!< 8-bit float (E8M0) */
kNVTENumTypes /*!< Number of supported types */
};
......@@ -373,14 +374,15 @@ namespace transformer_engine {
*/
enum class DType {
kByte = 0,
kInt32 = 1,
kInt64 = 2,
kFloat32 = 3,
kFloat16 = 4,
kBFloat16 = 5,
kFloat8E4M3 = 6,
kFloat8E5M2 = 7,
kFloat8E8M0 = 8,
kInt16 = 1,
kInt32 = 2,
kInt64 = 3,
kFloat32 = 4,
kFloat16 = 5,
kBFloat16 = 6,
kFloat8E4M3 = 7,
kFloat8E5M2 = 8,
kFloat8E8M0 = 9,
kNumTypes
};
......
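Note that inserting the Int16 enumerators at value 1 shifts every later value by one, in both the C and the C++ enums; the two stay value-compatible, which is what lets the wrappers below convert with a plain static_cast. A sketch of that invariant:

static_assert(static_cast<int>(transformer_engine::DType::kInt16) ==
                  static_cast<int>(kNVTEInt16),
              "NVTEDType and DType enumerators must stay in sync");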
......@@ -4,19 +4,16 @@
* See LICENSE for license information.
************************************************************************/
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <cuda_fp8.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include <cuda_fp8.h>
#include <transformer_engine/multi_tensor.h>
#include <transformer_engine/transformer_engine.h>
#include "common/utils.cuh"
#include "../utils.cuh"
#include "multi_tensor_apply.cuh"
#include "type_shim.h"
namespace transformer_engine {
namespace multi_tensor_adam {
#define BLOCK_SIZE 512
#define ILP 4
......@@ -30,7 +27,6 @@ typedef enum {
using MATH_T = float;
using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;
using transformer_engine::DType;
template <typename T>
struct is_fp8 : std::false_type {};
......@@ -576,12 +572,13 @@ struct AdamCapturableMasterFunctor {
}
};
void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor *>> tensor_lists, const float lr,
const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction,
const float weight_decay) {
using namespace at;
const float weight_decay, const int device_id, cudaStream_t stream) {
const size_t num_tensor_lists = tensor_lists.size();
const size_t num_tensors_per_list = tensor_lists[0].size();
// Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
......@@ -592,10 +589,10 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
size_t max_size = 0;
bool requires_64bit_indexing = false;
for (auto it = tensor_lists.begin(); it != tensor_lists.end(); it++) {
for (auto it2 = it->begin(); it2 != it->end(); it2++) {
if (it2->numel() > max_size) {
max_size = it2->numel();
for (size_t i = 0; i < num_tensor_lists; i++) {
for (size_t j = 0; j < num_tensors_per_list; j++) {
if (tensor_lists[i][j]->numel() > max_size) {
max_size = tensor_lists[i][j]->numel();
if (max_size >= INT_MAX) {
requires_64bit_indexing = true;
break;
......@@ -607,69 +604,70 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
}
}
const auto g_in_type = tensor_lists[0][0].scalar_type();
const auto p_in_type = tensor_lists[1][0].scalar_type();
auto tl_size = tensor_lists.size();
const auto g_in_type_te = tensor_lists[0][0]->dtype();
const auto p_in_type_te = tensor_lists[1][0]->dtype();
// case 4: g, p, m, v
// case 5: g, p, m, v, p_master
TORCH_CHECK(tl_size == 4 || tl_size == 5, "tensor list must contain 4 or 5");
NVTE_CHECK(num_tensor_lists == 4 || num_tensor_lists == 5, "tensor list must contain 4 or 5");
if (requires_64bit_indexing) {
if (tl_size == 4) {
if (num_tensor_lists == 4) {
// Assume single type across p,g,m1,m2 now
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
p_in_type, 0, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 1, "adam",
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
p_in_type_te, p_in_type,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
g_in_type_te, g_in_type,
multi_tensor_apply<4>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag,
tensor_lists,
AdamFunctor<scalar_t_0, scalar_t_1, float, int64_t>(), beta1,
beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay);));
AdamFunctor<p_in_type, g_in_type, float, int64_t>(), device_id,
stream, beta1, beta2, bias_correction1, bias_correction2,
epsilon, lr, (adamMode_t)mode, weight_decay);));
} else {
// g, p, m, v, p_master
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
p_in_type, 0, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 1, "adam",
multi_tensor_apply<5>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag,
tensor_lists,
AdamFunctorMaster<scalar_t_0, scalar_t_1, float, int64_t>(),
beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay);));
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
p_in_type_te, p_in_type,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
g_in_type_te, g_in_type,
multi_tensor_apply<5>(
(int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,
AdamFunctorMaster<p_in_type, g_in_type, float, int64_t>(), device_id, stream,
beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode,
weight_decay);));
}
} else {
if (tl_size == 4) {
if (num_tensor_lists == 4) {
// Assume single type across p,g,m1,m2 now
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
p_in_type, 0, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 1, "adam",
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
p_in_type_te, p_in_type,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
g_in_type_te, g_in_type,
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
AdamFunctor<scalar_t_0, scalar_t_1, float, int32_t>(), beta1,
beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay);));
AdamFunctor<p_in_type, g_in_type, float, int32_t>(), device_id,
stream, beta1, beta2, bias_correction1, bias_correction2,
epsilon, lr, (adamMode_t)mode, weight_decay);));
} else {
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
p_in_type, 0, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 1, "adam",
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
p_in_type_te, p_in_type,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
g_in_type_te, g_in_type,
multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
AdamFunctorMaster<scalar_t_0, scalar_t_1, float, int32_t>(),
beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay);));
AdamFunctorMaster<p_in_type, g_in_type, float, int32_t>(),
device_id, stream, beta1, beta2, bias_correction1,
bias_correction2, epsilon, lr, (adamMode_t)mode,
weight_decay);));
}
}
AT_CUDA_CHECK(cudaGetLastError());
NVTE_CHECK_CUDA(cudaGetLastError());
}
void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
void multi_tensor_adam_param_remainder_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor *>> tensor_lists,
const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode,
const int bias_correction, const float weight_decay) {
using namespace at;
const int bias_correction, const float weight_decay,
const int device_id, cudaStream_t stream) {
const size_t num_tensor_lists = tensor_lists.size();
// Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
......@@ -678,34 +676,34 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag
bias_correction2 = 1 - std::pow(beta2, step);
}
const auto g_in_type = tensor_lists[0][0].scalar_type();
const auto p_in_type = tensor_lists[1][0].scalar_type();
auto tl_size = tensor_lists.size();
const auto g_in_type_te = tensor_lists[0][0]->dtype();
const auto p_in_type_te = tensor_lists[1][0]->dtype();
// case 5: g, p, m, v, p_master
TORCH_CHECK(tl_size == 5, "tensor list must contain 5");
TORCH_CHECK(p_in_type == at::ScalarType::BFloat16,
NVTE_CHECK(num_tensor_lists == 5, "tensor list must contain 5");
NVTE_CHECK(p_in_type_te == DType::kBFloat16,
"Adam with BF16 param remainders requires BF16 params");
// g, p, m, v, p_master
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
p_in_type, 0, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 1, "adam",
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
g_in_type_te, g_in_type,
multi_tensor_apply<5>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,
AdamFunctorMasterParamRemainder<scalar_t_1, float, int64_t>(),
beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay);));
AdamFunctorMasterParamRemainder<g_in_type, float, int64_t>(), device_id,
stream, beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay););
AT_CUDA_CHECK(cudaGetLastError());
NVTE_CHECK_CUDA(cudaGetLastError());
}
void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor *>> tensor_lists, const float lr,
const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction,
const float weight_decay, DType fp8_dtype) {
using namespace at;
const float weight_decay, const DType fp8_dtype,
const int device_id, cudaStream_t stream) {
const size_t num_tensor_lists = tensor_lists.size();
const size_t num_tensors_per_list = tensor_lists[0].size();
// Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
......@@ -716,10 +714,10 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
size_t max_size = 0;
bool requires_64bit_indexing = false;
for (auto it = tensor_lists.begin(); it != tensor_lists.end(); it++) {
for (auto it2 = it->begin(); it2 != it->end(); it2++) {
if (it2->numel() > max_size) {
max_size = it2->numel();
for (size_t i = 0; i < num_tensor_lists; i++) {
for (size_t j = 0; j < num_tensors_per_list; j++) {
if (tensor_lists[i][j]->numel() > max_size) {
max_size = tensor_lists[i][j]->numel();
if (max_size >= INT_MAX) {
requires_64bit_indexing = true;
break;
......@@ -731,66 +729,147 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
}
}
const auto g_in_type = tensor_lists[0][0].scalar_type();
auto tl_size = tensor_lists.size();
const auto g_in_type_te = tensor_lists[0][0]->dtype();
// case 8: g, p_fp8, m, v, p_master, scale, amax, scale_inv
TORCH_CHECK(tl_size == 8, "tensor list must contain 8 tensors");
NVTE_CHECK(num_tensor_lists == 8, "tensor list must contain 8 tensors");
if (requires_64bit_indexing) {
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
fp8_dtype, FP8_T,
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 0, "adam",
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
g_in_type_te, g_in_type,
multi_tensor_apply<5, true>(
(int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,
AdamFunctorMaster<FP8_T, scalar_t_0, float, int64_t>(), beta1, beta2,
bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, weight_decay);));
AdamFunctorMaster<FP8_T, g_in_type, float, int64_t>(), device_id, stream, beta1,
beta2, bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode,
weight_decay);));
} else {
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
fp8_dtype, FP8_T,
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 0, "adam",
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
g_in_type_te, g_in_type,
multi_tensor_apply<5, true>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
AdamFunctorMaster<FP8_T, scalar_t_0, float, int32_t>(),
beta1, beta2, bias_correction1, bias_correction2, epsilon,
lr, (adamMode_t)mode, weight_decay);));
AdamFunctorMaster<FP8_T, g_in_type, float, int32_t>(),
device_id, stream, beta1, beta2, bias_correction1,
bias_correction2, epsilon, lr, (adamMode_t)mode,
weight_decay);));
}
AT_CUDA_CHECK(cudaGetLastError());
NVTE_CHECK_CUDA(cudaGetLastError());
}
void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor lr, const float beta1, const float beta2,
const float epsilon, at::Tensor step, const int mode,
const int bias_correction, const float weight_decay,
at::Tensor inv_scale) {
using namespace at;
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
tensor_lists[0][0].scalar_type(), 0, "adam",
void multi_tensor_adam_capturable_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor *>> tensor_lists, Tensor lr,
const float beta1, const float beta2, const float epsilon,
Tensor step, const int mode, const int bias_correction,
const float weight_decay, Tensor inv_scale,
const int device_id, cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
tensor_lists[0][0]->dtype(), dtype,
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
AdamCapturableFunctor<scalar_t_0, float>(), beta1, beta2,
step.data_ptr<int>(), bias_correction, epsilon, lr.data_ptr<float>(),
(adamMode_t)mode, weight_decay, inv_scale.data_ptr<float>());)
AdamCapturableFunctor<dtype, float>(), device_id, stream, beta1, beta2,
reinterpret_cast<int *>(step.data.dptr), bias_correction, epsilon,
reinterpret_cast<float *>(lr.data.dptr), (adamMode_t)mode, weight_decay,
reinterpret_cast<float *>(inv_scale.data.dptr));)
AT_CUDA_CHECK(cudaGetLastError());
NVTE_CHECK_CUDA(cudaGetLastError());
}
void multi_tensor_adam_capturable_master_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor lr, const float beta1, const float beta2,
const float epsilon, at::Tensor step, const int mode,
void multi_tensor_adam_capturable_master_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor *>> tensor_lists,
Tensor lr, const float beta1, const float beta2,
const float epsilon, Tensor step, const int mode,
const int bias_correction, const float weight_decay,
at::Tensor inv_scale) {
using namespace at;
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
tensor_lists[0][0].scalar_type(), 0, "adam",
Tensor inv_scale, const int device_id,
cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
tensor_lists[0][0]->dtype(), dtype,
multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
AdamCapturableMasterFunctor<scalar_t_0, float>(), beta1, beta2,
step.data_ptr<int>(), bias_correction, epsilon, lr.data_ptr<float>(),
(adamMode_t)mode, weight_decay, inv_scale.data_ptr<float>());)
AdamCapturableMasterFunctor<dtype, float>(), device_id, stream, beta1,
beta2, reinterpret_cast<int *>(step.data.dptr), bias_correction,
epsilon, reinterpret_cast<float *>(lr.data.dptr), (adamMode_t)mode,
weight_decay, reinterpret_cast<float *>(inv_scale.data.dptr));)
NVTE_CHECK_CUDA(cudaGetLastError());
}
} // namespace multi_tensor_adam
} // namespace transformer_engine
void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode,
const int bias_correction, const float weight_decay,
const int device_id, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_adam_cuda);
using namespace transformer_engine;
multi_tensor_adam::multi_tensor_adam_cuda(
chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2,
epsilon, step, mode, bias_correction, weight_decay, device_id, stream);
}
void nvte_multi_tensor_adam_param_remainder_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode, const int bias_correction,
const float weight_decay, const int device_id, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_adam_param_remainder_cuda);
using namespace transformer_engine;
multi_tensor_adam::multi_tensor_adam_param_remainder_cuda(
chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2,
epsilon, step, mode, bias_correction, weight_decay, device_id, stream);
}
void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, const float lr,
const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction,
const float weight_decay, const NVTEDType fp8_dtype,
const int device_id, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_adam_fp8_cuda);
using namespace transformer_engine;
multi_tensor_adam::multi_tensor_adam_fp8_cuda(
chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2,
epsilon, step, mode, bias_correction, weight_decay, static_cast<DType>(fp8_dtype), device_id,
stream);
}
void nvte_multi_tensor_adam_capturable_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2,
const float epsilon, NVTETensor step, const int mode, const int bias_correction,
const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_adam_capturable_cuda);
using namespace transformer_engine;
multi_tensor_adam::multi_tensor_adam_capturable_cuda(
chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
*reinterpret_cast<Tensor *>(lr), beta1, beta2, epsilon, *reinterpret_cast<Tensor *>(step),
mode, bias_correction, weight_decay, *reinterpret_cast<Tensor *>(inv_scale), device_id,
stream);
}
AT_CUDA_CHECK(cudaGetLastError());
void nvte_multi_tensor_adam_capturable_master_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2,
const float epsilon, NVTETensor step, const int mode, const int bias_correction,
const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_adam_capturable_master_cuda);
using namespace transformer_engine;
multi_tensor_adam::multi_tensor_adam_capturable_master_cuda(
chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
*reinterpret_cast<Tensor *>(lr), beta1, beta2, epsilon, *reinterpret_cast<Tensor *>(step),
mode, bias_correction, weight_decay, *reinterpret_cast<Tensor *>(inv_scale), device_id,
stream);
}
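For reference, the bias_correction1/bias_correction2 terms computed above follow the standard Adam recipe. A scalar sketch of one bias-corrected update (weight decay omitted, since the functor applies it in a mode-dependent way; this is the textbook rule, not a line-for-line copy of the kernel):

#include <cmath>

float adam_step(float p, float g, float &m, float &v, float lr, float beta1,
                float beta2, float eps, int step) {
  m = beta1 * m + (1.f - beta1) * g;            // first moment
  v = beta2 * v + (1.f - beta2) * g * g;        // second moment
  const float m_hat = m / (1.f - std::pow(beta1, step));  // bias_correction1
  const float v_hat = v / (1.f - std::pow(beta2, step));  // bias_correction2
  return p - lr * m_hat / (std::sqrt(v_hat) + eps);
}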
......@@ -4,23 +4,21 @@
* See LICENSE for license information.
************************************************************************/
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include <limits>
// Stringstream is a big hammer, but I want to rely on operator<< for dtype.
#include <assert.h>
#include <cuda_fp8.h>
#include <transformer_engine/multi_tensor.h>
#include <transformer_engine/transformer_engine.h>
#include <sstream>
#include "common/recipe/recipe_common.cuh"
#include "common/utils.cuh"
#include "../recipe/recipe_common.cuh"
#include "../utils.cuh"
#include "multi_tensor_apply.cuh"
#include "type_shim.h"
namespace transformer_engine {
namespace multi_tensor_compute_scale {
#define BLOCK_SIZE 256
......@@ -57,12 +55,29 @@ struct ComputeScaleAndScaleInvFunctor {
}
};
void multi_tensor_compute_scale_and_scale_inv_cuda(
int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
float max_fp8, bool force_pow_2_scales, float epsilon) {
using namespace at;
void multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor *>> tensor_lists,
float max_fp8, bool force_pow_2_scales,
float epsilon, const int device_id,
cudaStream_t stream) {
multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
ComputeScaleAndScaleInvFunctor(), max_fp8, force_pow_2_scales, epsilon);
AT_CUDA_CHECK(cudaGetLastError());
ComputeScaleAndScaleInvFunctor(), device_id, stream, max_fp8,
force_pow_2_scales, epsilon);
NVTE_CHECK_CUDA(cudaGetLastError());
}
} // namespace multi_tensor_compute_scale
} // namespace transformer_engine
void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, float max_fp8, int force_pow_2_scales, float epsilon,
const int device_id, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_compute_scale_and_scale_inv_cuda);
using namespace transformer_engine;
multi_tensor_compute_scale::multi_tensor_compute_scale_and_scale_inv_cuda(
chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), max_fp8,
force_pow_2_scales, epsilon, device_id, stream);
}
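The per-tensor scale math itself lives in recipe_common.cuh; as a hedged host-side sketch of what "compute scale and scale_inv" means here (exact clamping details may differ from the device code):

#include <cmath>

// Sketch: scale maps the observed amax onto the representable FP8 range;
// force_pow_2 rounds it down to a power of two so scale * scale_inv
// round-trips exactly.
void compute_scale_and_inv(float amax, float max_fp8, bool force_pow_2,
                           float epsilon, float &scale, float &scale_inv) {
  scale = max_fp8 / std::fmax(amax, epsilon);
  if (force_pow_2) scale = std::exp2(std::floor(std::log2(scale)));
  scale_inv = 1.f / scale;
}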
......@@ -4,18 +4,16 @@
* See LICENSE for license information.
************************************************************************/
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include <cuda_fp8.h>
#include <transformer_engine/multi_tensor.h>
#include <transformer_engine/transformer_engine.h>
#include "../utils.cuh"
#include "multi_tensor_apply.cuh"
#include "type_shim.h"
namespace transformer_engine {
namespace multi_tensor_l2norm {
#define BLOCK_SIZE 512
#define ILP 4
......@@ -31,6 +29,88 @@ __device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset, int s
((LT *)dst)[dst_offset] = ((LT *)src)[src_offset]; // NOLINT(*)
}
template <typename T>
__device__ __forceinline__ T
reduce_block_into_lanes(T *x, T val, int lanes = 1,
bool share_result = false) { // lanes is intended to be <= 32.
int tid = threadIdx.x + threadIdx.y * blockDim.x;
int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
if (blockSize >= 64) {
x[tid] = val;
__syncthreads();
}
#pragma unroll
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
if (tid < i) x[tid] = x[tid] + x[tid + i];
__syncthreads();
}
T final;
if (tid < 32) {
if (blockSize >= 64)
final = x[tid] + x[tid + 32];
else
final = val;
// __SYNCWARP();
#pragma unroll
for (int i = 16; i >= lanes; i >>= 1) final = final + __shfl_down_sync(0xffffffff, final, i);
}
  if (share_result) {
    if (tid < lanes) x[tid] = final;  // EpilogueOp
  }

  // Make sure the smem result is visible to all warps, and avoid a potential
  // write-before-read race when reduce_block_into_lanes is called back to back.
  __syncthreads();

  return final;
}
template <typename T>
__device__ __forceinline__ T
reduce_block_into_lanes_max_op(T *x, T val, int lanes = 1,
bool share_result = false) { // lanes is intended to be <= 32.
int tid = threadIdx.x + threadIdx.y * blockDim.x;
int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
if (blockSize >= 64) {
x[tid] = val;
__syncthreads();
}
#pragma unroll
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
__syncthreads();
}
T final;
if (tid < 32) {
if (blockSize >= 64)
final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
else
final = val;
// __SYNCWARP();
#pragma unroll
for (int i = 16; i >= lanes; i >>= 1)
final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
}
if (share_result) {
if (tid < lanes) x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}
template <typename x_t>
struct L2NormFunctor {
__device__ __forceinline__ void operator()(int chunk_size, volatile int *noop_gmem,
......@@ -56,7 +136,7 @@ struct L2NormFunctor {
x_t r_x[ILP];
for (int i = 0; i < ILP; i++) {
vals[i] = 0.f;
r_x[i] = 0;
r_x[i] = 0.f;
}
// to make things simple, we put aligned case in a different code path
......@@ -126,7 +206,7 @@ struct UnscaleL2NormFunctor {
x_t r_x[ILP];
for (int i = 0; i < ILP; i++) {
vals[i] = 0.f;
r_x[i] = 0;
r_x[i] = 0.f;
}
// to make things simple, we put aligned case in a different code path
......@@ -310,103 +390,96 @@ __global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret,
}
}
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
at::optional<bool> per_tensor_python) {
bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;
auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
auto output = at::zeros({320}, float_options);
at::Tensor output_per_tensor;
at::Tensor ret_per_tensor;
int ntensors = tensor_lists[0].size();
int max_chunks_per_tensor = -1;
if (per_tensor) {
for (int t = 0; t < ntensors; t++) {
int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
if (max_chunks_this_tensor > max_chunks_per_tensor)
max_chunks_per_tensor = max_chunks_this_tensor;
}
output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor}, float_options);
ret_per_tensor = at::empty({ntensors}, float_options);
} else {
ret_per_tensor = at::empty({0}, float_options);
}
DISPATCH_FLOAT_HALF_AND_BFLOAT(
tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
multi_tensor_apply<1>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
L2NormFunctor<scalar_t_0>(), output.data_ptr<float>(),
per_tensor ? output_per_tensor.data_ptr<float>() : nullptr, per_tensor,
void multi_tensor_l2norm_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor *>> tensor_lists, Tensor output,
Tensor output_per_tensor, Tensor ret, Tensor ret_per_tensor,
bool per_tensor, int max_chunks_per_tensor, const int device_id,
cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
tensor_lists[0][0]->dtype(), dtype,
multi_tensor_apply<1>(
BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, L2NormFunctor<dtype>(), device_id,
stream, reinterpret_cast<float *>(output.data.dptr),
per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr, per_tensor,
max_chunks_per_tensor);)
AT_CUDA_CHECK(cudaGetLastError());
// AT_CUDA_CHECK(cudaDeviceSynchronize());
NVTE_CHECK_CUDA(cudaGetLastError());
// This involves one more small kernel launch, but it will be negligible end to end.
// I could get rid of these by hacking the functor + multi tensor harness with persistence
// logic, but keeping it simple for now
auto ret = at::empty({1}, output.options());
const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
auto stream = at::cuda::getCurrentCUDAStream();
cleanup<<<per_tensor ? ntensors : 1, 512, 0, stream>>>(
output.data_ptr<float>(), per_tensor ? output_per_tensor.data_ptr<float>() : nullptr,
ret.data_ptr<float>(), per_tensor ? ret_per_tensor.data_ptr<float>() : nullptr, per_tensor,
const OptionalCUDAGuard device_guard(device_id);
cleanup<<<per_tensor ? tensor_lists[0].size() : 1, 512, 0, stream>>>(
reinterpret_cast<float *>(output.data.dptr),
per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr,
reinterpret_cast<float *>(ret.data.dptr),
per_tensor ? reinterpret_cast<float *>(ret_per_tensor.data.dptr) : nullptr, per_tensor,
max_chunks_per_tensor);
return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
}
std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(
int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor inv_scale, at::optional<bool> per_tensor_python) {
bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;
auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
auto output = at::zeros({320}, float_options);
at::Tensor output_per_tensor;
at::Tensor ret_per_tensor;
int ntensors = tensor_lists[0].size();
int max_chunks_per_tensor = -1;
if (per_tensor) {
for (int t = 0; t < ntensors; t++) {
int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
if (max_chunks_this_tensor > max_chunks_per_tensor)
max_chunks_per_tensor = max_chunks_this_tensor;
}
output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor}, float_options);
ret_per_tensor = at::empty({ntensors}, float_options);
} else {
ret_per_tensor = at::empty({0}, float_options);
}
DISPATCH_FLOAT_HALF_AND_BFLOAT(
tensor_lists[0][0].scalar_type(), 0, "multi_tensor_unscale_l2norm_cuda",
multi_tensor_apply<1>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
UnscaleL2NormFunctor<scalar_t_0>(), inv_scale.data_ptr<float>(),
output.data_ptr<float>(),
per_tensor ? output_per_tensor.data_ptr<float>() : nullptr, per_tensor,
void multi_tensor_unscale_l2norm_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor *>> tensor_lists,
Tensor output, Tensor output_per_tensor, Tensor ret,
Tensor ret_per_tensor, Tensor inv_scale, bool per_tensor,
int max_chunks_per_tensor, const int device_id,
cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
tensor_lists[0][0]->dtype(), dtype,
multi_tensor_apply<1>(
BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, UnscaleL2NormFunctor<dtype>(), device_id,
stream, reinterpret_cast<float *>(inv_scale.data.dptr),
reinterpret_cast<float *>(output.data.dptr),
per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr, per_tensor,
max_chunks_per_tensor);)
AT_CUDA_CHECK(cudaGetLastError());
// AT_CUDA_CHECK(cudaDeviceSynchronize());
NVTE_CHECK_CUDA(cudaGetLastError());
// This involves one more small kernel launch, but it will be negligible end to end.
// I could get rid of these by hacking the functor + multi tensor harness with persistence
// logic, but keeping it simple for now
auto ret = at::empty({1}, output.options());
const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
auto stream = at::cuda::getCurrentCUDAStream();
cleanup<<<per_tensor ? ntensors : 1, 512, 0, stream>>>(
output.data_ptr<float>(), per_tensor ? output_per_tensor.data_ptr<float>() : nullptr,
ret.data_ptr<float>(), per_tensor ? ret_per_tensor.data_ptr<float>() : nullptr, per_tensor,
const OptionalCUDAGuard device_guard(device_id);
cleanup<<<per_tensor ? tensor_lists[0].size() : 1, 512, 0, stream>>>(
reinterpret_cast<float *>(output.data.dptr),
per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr,
reinterpret_cast<float *>(ret.data.dptr),
per_tensor ? reinterpret_cast<float *>(ret_per_tensor.data.dptr) : nullptr, per_tensor,
max_chunks_per_tensor);
}
} // namespace multi_tensor_l2norm
} // namespace transformer_engine
void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
NVTETensor output, NVTETensor output_per_tensor, NVTETensor ret,
NVTETensor ret_per_tensor, int per_tensor,
int max_chunks_per_tensor, const int device_id,
cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_l2norm_cuda);
using namespace transformer_engine;
multi_tensor_l2norm::multi_tensor_l2norm_cuda(
chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
*reinterpret_cast<Tensor *>(output), *reinterpret_cast<Tensor *>(output_per_tensor),
*reinterpret_cast<Tensor *>(ret), *reinterpret_cast<Tensor *>(ret_per_tensor), per_tensor,
max_chunks_per_tensor, device_id, stream);
}
return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, NVTETensor output,
NVTETensor output_per_tensor, NVTETensor ret,
NVTETensor ret_per_tensor, NVTETensor inv_scale,
int per_tensor, int max_chunks_per_tensor,
const int device_id, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_unscale_l2norm_cuda);
using namespace transformer_engine;
multi_tensor_l2norm::multi_tensor_unscale_l2norm_cuda(
chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
*reinterpret_cast<Tensor *>(output), *reinterpret_cast<Tensor *>(output_per_tensor),
*reinterpret_cast<Tensor *>(ret), *reinterpret_cast<Tensor *>(ret_per_tensor),
*reinterpret_cast<Tensor *>(inv_scale), per_tensor, max_chunks_per_tensor, device_id, stream);
}
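The cleanup launch that both functions end with reduces the per-chunk partial sums of squares written by the main kernel. A host-side reference of the math (a sketch, not the CUDA cleanup kernel):

#include <cmath>

float l2norm_from_partials(const float *partials, int n) {
  float acc = 0.f;
  for (int i = 0; i < n; ++i) acc += partials[i];  // sum of squared chunks
  return std::sqrt(acc);
}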
......@@ -5,17 +5,62 @@
************************************************************************/
#pragma once
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <assert.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <transformer_engine/multi_tensor.h>
#include <transformer_engine/transformer_engine.h>
#include "common/common.h"
#include "../common.h"
// This header is the one-stop shop for all your multi-tensor apply needs.
// Change device if needed.
class OptionalCUDAGuard {
public:
explicit OptionalCUDAGuard(int new_device) {
if (new_device < 0) return;
int current_device;
NVTE_CHECK_CUDA(cudaGetDevice(&current_device));
if (new_device != current_device) {
NVTE_CHECK_CUDA(cudaSetDevice(new_device));
device_changed_ = true;
prev_device_ = current_device;
}
}
OptionalCUDAGuard(const OptionalCUDAGuard &) = delete;
OptionalCUDAGuard &operator=(const OptionalCUDAGuard &) = delete;
OptionalCUDAGuard(OptionalCUDAGuard &&other) noexcept
: prev_device_(other.prev_device_), device_changed_(other.device_changed_) {
other.device_changed_ = false;
}
OptionalCUDAGuard &operator=(OptionalCUDAGuard &&other) noexcept {
if (this != &other) {
if (device_changed_) {
cudaSetDevice(prev_device_);
}
prev_device_ = other.prev_device_;
device_changed_ = other.device_changed_;
other.device_changed_ = false;
}
return *this;
}
~OptionalCUDAGuard() {
if (device_changed_) {
NVTE_CHECK_CUDA(cudaSetDevice(prev_device_));
}
}
private:
int prev_device_;
bool device_changed_ = false;
};
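Usage mirrors the c10::cuda::OptionalCUDAGuard this class replaces: construct it with the target device (negative means leave the current device alone) and the previous device is restored on scope exit. A sketch:

void run_on_device(int device_id, cudaStream_t stream) {
  const OptionalCUDAGuard guard(device_id);  // no-op if device_id < 0
  // ... enqueue kernels on `stream` here; they run on device_id ...
}  // destructor restores the previous device if it was changed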
// TODO: Kernel arg size limit may be <4KB for some other cards (i.e. Jetson)
constexpr int depth_to_max_tensors[6] = {110, 64, 48, 36, 30, 24};
constexpr int depth_to_max_blocks[6] = {320, 320, 320, 320, 320, 320};
......@@ -46,61 +91,39 @@ __global__ void multi_tensor_apply_kernel(int64_t chunk_size, volatile int *noop
}
template <int depth, bool USE_FP8 = false, typename T, typename... ArgTypes>
void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor &noop_flag,
const std::vector<std::vector<at::Tensor>> &tensor_lists, T callable,
ArgTypes... args) {
void multi_tensor_apply(int64_t block_size, int64_t chunk_size,
const transformer_engine::Tensor &noop_flag,
std::vector<std::vector<transformer_engine::Tensor *>> tensor_lists,
T callable, const int device_id, cudaStream_t stream, ArgTypes... args) {
const size_t num_tensor_lists = tensor_lists.size();
const size_t num_tensors_per_list = tensor_lists[0].size();
if constexpr (USE_FP8) {
TORCH_CHECK(tensor_lists.size() == depth + 3,
NVTE_CHECK(num_tensor_lists == depth + 3,
"tensor_lists.size() != depth + 3, tensor_lists should have 3 more tensors (scale, "
"amax, scale_inv) for fp8");
} else {
TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
}
int len0 = tensor_lists[0].size();
TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
auto ref_device = tensor_lists[0][0].device();
TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
for (int l = 0; l < depth; l++) { // No range-based for because I need indices
TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
for (int t = 0; t < tensor_lists[l].size(); t++) {
// TODO: Print which tensor fails.
bool contiguous_memory = tensor_lists[l][t].is_contiguous();
contiguous_memory =
(contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast) ||
tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast3d));
TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
TORCH_CHECK(tensor_lists[l][t].device() == ref_device,
"A tensor was not on the same device as the first tensor");
TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
NVTE_CHECK(num_tensor_lists == depth, "tensor_lists.size() != depth");
}
}
if constexpr (USE_FP8) {
TORCH_CHECK(tensor_lists[depth].size() == len0 && tensor_lists[depth + 1].size() == len0,
"Size mismatch among tensor lists");
}
int ntensors = tensor_lists[0].size();
TensorListMetadata<depth, USE_FP8> tl;
const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
auto stream = at::cuda::getCurrentCUDAStream();
const OptionalCUDAGuard device_guard(device_id);
tl.start_tensor_this_launch = 0;
int loc_block_info = 0;
int loc_tensor_info = 0;
for (int t = 0; t < ntensors; t++) {
tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
for (int t = 0; t < num_tensors_per_list; t++) {
tl.sizes[loc_tensor_info] = tensor_lists[0][t]->numel();
for (int d = 0; d < depth; d++)
tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
tl.addresses[d][loc_tensor_info] = tensor_lists[d][t]->data.dptr;
if constexpr (USE_FP8) {
for (int i = 0; i < 3; i++)
tl.fp8_meta_addresses[i][loc_tensor_info] = tensor_lists[depth + i][t].data_ptr();
tl.fp8_meta_addresses[i][loc_tensor_info] = tensor_lists[depth + i][t]->data.dptr;
}
loc_tensor_info++;
auto chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
auto chunks_this_tensor = (tensor_lists[0][t]->numel() + chunk_size - 1) / chunk_size;
for (auto chunk = 0; chunk < chunks_this_tensor; chunk++) {
tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
......@@ -110,12 +133,12 @@ void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor
bool tensors_full =
(loc_tensor_info == depth_to_max_tensors[depth - 1] && chunk == chunks_this_tensor - 1);
bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]);
bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);
bool last_chunk = (t == num_tensors_per_list - 1 && chunk == chunks_this_tensor - 1);
if (tensors_full || blocks_full || last_chunk) {
multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(
chunk_size, noop_flag.data_ptr<int>(), tl, callable, args...);
chunk_size, reinterpret_cast<int *>(noop_flag.data.dptr), tl, callable, args...);
AT_CUDA_CHECK(cudaGetLastError());
NVTE_CHECK_CUDA(cudaGetLastError());
// Reset. The control flow possibilities here make my brain hurt.
loc_block_info = 0;
......
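The launch logic above packs tensor addresses and a block-to-(tensor, chunk) map into TensorListMetadata, flushing a kernel launch whenever the address table or block table fills, or at the last chunk. The chunk-count arithmetic, as a standalone sketch:

#include <cstdint>

// Ceil division used for chunks_this_tensor above;
// e.g. chunks_for(200000, 65536) == 4.
int64_t chunks_for(int64_t numel, int64_t chunk_size) {
  return (numel + chunk_size - 1) / chunk_size;
}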
......@@ -4,19 +4,20 @@
* See LICENSE for license information.
************************************************************************/
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include <cuda_fp8.h>
// Stringstream is a big hammer, but I want to rely on operator<< for dtype.
#include <transformer_engine/multi_tensor.h>
#include <transformer_engine/transformer_engine.h>
#include <iostream>
#include <sstream>
#include "../utils.cuh"
#include "multi_tensor_apply.cuh"
#include "type_shim.h"
namespace transformer_engine {
namespace multi_tensor_scale {
#define BLOCK_SIZE 512
#define ILP 4
......@@ -66,7 +67,7 @@ struct ScaleFunctor {
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
r_out[ii] = static_cast<float>(r_in[ii]) * scale;
finite = finite && isfinite(r_in[ii]);
finite = finite && isfinite(static_cast<float>(r_in[ii]));
}
// store
load_store(out, r_out, i_start, 0);
......@@ -76,7 +77,7 @@ struct ScaleFunctor {
for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
r_in[ii] = 0;
r_in[ii] = 0.f;
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) r_in[ii] = in[i];
}
......@@ -88,7 +89,7 @@ struct ScaleFunctor {
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
r_out[ii] = static_cast<float>(r_in[ii]) * scale;
finite = finite && isfinite(r_in[ii]);
finite = finite && isfinite(static_cast<float>(r_in[ii]));
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
......@@ -101,20 +102,29 @@ struct ScaleFunctor {
}
};
void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, float scale) {
using namespace at;
// The output (downscaled) type is always float.
// If build times suffer, think about where to put this dispatch,
// and what logic should be moved out of multi_tensor_apply.
DISPATCH_FLOAT_HALF_AND_BFLOAT(
tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda",
DISPATCH_FLOAT_HALF_AND_BFLOAT(
tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda",
void multi_tensor_scale_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor *>> tensor_lists, float scale,
const int device_id, cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
tensor_lists[0][0]->dtype(), p_in_type,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
tensor_lists[1][0]->dtype(), g_in_type,
multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
ScaleFunctor<scalar_t_0, scalar_t_1>(), scale);))
AT_CUDA_CHECK(cudaGetLastError());
ScaleFunctor<p_in_type, g_in_type>(), device_id, stream, scale);))
NVTE_CHECK_CUDA(cudaGetLastError());
}
} // namespace multi_tensor_scale
} // namespace transformer_engine
void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
float scale, const int device_id, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_scale_cuda);
using namespace transformer_engine;
// AT_CUDA_CHECK(cudaDeviceSynchronize());
multi_tensor_scale::multi_tensor_scale_cuda(
chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), scale, device_id,
stream);
}
......@@ -4,14 +4,16 @@
* See LICENSE for license information.
************************************************************************/
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <assert.h>
#include <cuda_fp8.h>
#include <transformer_engine/multi_tensor.h>
#include <transformer_engine/transformer_engine.h>
#include "../utils.cuh"
#include "multi_tensor_apply.cuh"
#include "type_shim.h"
namespace transformer_engine {
namespace multi_tensor_sgd {
#define BLOCK_SIZE 512
#define ILP 4
......@@ -54,9 +56,9 @@ struct SGDFunctor {
T_weight* mom_in = reinterpret_cast<T_weight*>(tl.addresses[2][tensor_loc]);
mom_in += chunk_idx * chunk_size;
at::Half* model_weights_out = nullptr;
fp16* model_weights_out = nullptr;
if (N == 4) {
model_weights_out = (at::Half*)tl.addresses[3][tensor_loc];
model_weights_out = reinterpret_cast<fp16*>(tl.addresses[3][tensor_loc]);
model_weights_out += chunk_idx * chunk_size;
}
......@@ -112,7 +114,7 @@ struct SGDFunctor {
weight_in[i] += (-lr * incoming_grads[ii]);
// if necessary, write out an fp16 copy of the weights
if (N == 4) model_weights_out[i] = static_cast<at::Half>(weight_in[i]);
if (N == 4) model_weights_out[i] = static_cast<fp16>(weight_in[i]);
// also write out the new momentum
if (momentum != 0.f) mom_in[i] = incoming_moms[ii];
......@@ -122,23 +124,23 @@ struct SGDFunctor {
}
};
void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, float wd,
float momentum, float dampening, float lr, bool nesterov, bool first_run,
bool wd_after_momentum, float scale) {
auto num_tensors = tensor_lists.size();
auto grad_type = tensor_lists[0][0].scalar_type();
auto weight_type = tensor_lists[1][0].scalar_type();
if (num_tensors == 4) {
for (int i = 0; i < tensor_lists[3].size(); i++)
TORCH_CHECK(tensor_lists[3][i].scalar_type() == at::ScalarType::Half,
void multi_tensor_sgd_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor*>> tensor_lists, float wd, float momentum,
float dampening, float lr, bool nesterov, bool first_run,
bool wd_after_momentum, float scale, const int device_id,
cudaStream_t stream) {
const size_t num_tensor_lists = tensor_lists.size();
const size_t num_tensors_per_list = tensor_lists[0].size();
auto grad_type = tensor_lists[0][0]->dtype();
auto weight_type = tensor_lists[1][0]->dtype();
if (num_tensor_lists == 4) {
for (int i = 0; i < num_tensors_per_list; i++)
NVTE_CHECK(tensor_lists[3][i]->dtype() == DType::kFloat16,
"Additional output tensors should always be fp16.");
}
TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(),
"expected noop flag to be on the same device as tensors");
// We have 3 possibilities to handle here, in terms of
// grad_type, param_type, momentum_type, requires_fp16_copy
// 1. fp16, fp16, fp16, No
......@@ -150,54 +152,51 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
// we don't want the majority of them.
// Case 1. fp16, fp16, fp16, No
if (grad_type == at::ScalarType::Half && weight_type == at::ScalarType::Half &&
num_tensors == 3) {
if (grad_type == DType::kFloat16 && weight_type == DType::kFloat16 && num_tensor_lists == 3) {
multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
SGDFunctor<3, at::Half, at::Half>(), wd, momentum, dampening, lr,
nesterov, first_run, wd_after_momentum, scale);
SGDFunctor<3, fp16, fp16>(), device_id, stream, wd, momentum, dampening,
lr, nesterov, first_run, wd_after_momentum, scale);
}
// Case 2. fp16, fp32, fp32, No
// else if (grad_type == at::ScalarType::Half &&
// weight_type == at::ScalarType::Float &&
// num_tensors == 3) {
// multi_tensor_apply<3>(
// BLOCK_SIZE,
// chunk_size,
// noop_flag,
// tensor_lists,
// SGDFunctor<3, at::Half, float>(),
// wd,
// momentum,
// dampening,
// lr,
// nesterov,
// first_run,
// wd_after_momentum);
// }
// Case 2. fp32, fp32, fp32, No
else if (grad_type == at::ScalarType::Float && // NOLINT(*)
weight_type == at::ScalarType::Float && num_tensors == 3) {
else if (grad_type == DType::kFloat32 && // NOLINT(*)
weight_type == DType::kFloat32 && num_tensor_lists == 3) {
multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
SGDFunctor<3, float, float>(), wd, momentum, dampening, lr, nesterov,
first_run, wd_after_momentum, scale);
SGDFunctor<3, float, float>(), device_id, stream, wd, momentum, dampening,
lr, nesterov, first_run, wd_after_momentum, scale);
}
// Case 3. fp16, fp32, fp32, Yes
else if (grad_type == at::ScalarType::Half && // NOLINT(*)
weight_type == at::ScalarType::Float && num_tensors == 4) {
else if (grad_type == DType::kFloat16 && // NOLINT(*)
weight_type == DType::kFloat32 && num_tensor_lists == 4) {
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
SGDFunctor<4, at::Half, float>(), wd, momentum, dampening, lr, nesterov,
first_run, wd_after_momentum, scale);
SGDFunctor<4, fp16, float>(), device_id, stream, wd, momentum, dampening,
lr, nesterov, first_run, wd_after_momentum, scale);
}
// Case 4. fp32, fp32, fp32, Yes
else if (grad_type == at::ScalarType::Float && // NOLINT(*)
weight_type == at::ScalarType::Float && num_tensors == 4) {
else if (grad_type == DType::kFloat32 && // NOLINT(*)
weight_type == DType::kFloat32 && num_tensor_lists == 4) {
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
SGDFunctor<4, float, float>(), wd, momentum, dampening, lr, nesterov,
first_run, wd_after_momentum, scale);
SGDFunctor<4, float, float>(), device_id, stream, wd, momentum, dampening,
lr, nesterov, first_run, wd_after_momentum, scale);
} else {
AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ",
"gradient: ", grad_type, ", weight: ", weight_type, ", num_lists: ", num_tensors);
NVTE_ERROR("Unsupported combination of gradient and weight types. Given gradient: ",
to_string(grad_type), ", weight: ", to_string(weight_type),
", num_lists: ", num_tensor_lists);
}
AT_CUDA_CHECK(cudaGetLastError());
NVTE_CHECK_CUDA(cudaGetLastError());
}
} // namespace multi_tensor_sgd
} // namespace transformer_engine
void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor** tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
float wd, float momentum, float dampening, float lr, int nesterov,
int first_run, int wd_after_momentum, float scale,
const int device_id, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_sgd_cuda);
using namespace transformer_engine;
multi_tensor_sgd::multi_tensor_sgd_cuda(
chunk_size, *reinterpret_cast<Tensor*>(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), wd, momentum,
dampening, lr, nesterov, first_run, wd_after_momentum, scale, device_id, stream);
}
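For orientation, a minimal sketch (not part of this diff) of driving the new C-level entry point directly. It assumes the caller has already built NVTETensor handles for n FP32 gradients, weights, and momentum buffers on the given device, plus a single-element int32 noop flag; the {grads, weights, momenta} list order is read off the dispatch code above.

#include <cuda_runtime.h>
#include <transformer_engine/multi_tensor.h>

// Hypothetical helper: launch one fused SGD step over `n` FP32 tensors.
void sgd_step_sketch(NVTETensor noop, NVTETensor* grads, NVTETensor* weights,
                     NVTETensor* momenta, size_t n, int device,
                     cudaStream_t stream) {
  NVTETensor* lists[3] = {grads, weights, momenta};  // {grad, param, momentum}
  nvte_multi_tensor_sgd_cuda(/*chunk_size=*/65536, noop, lists,
                             /*num_tensor_lists=*/3, /*num_tensors_per_list=*/n,
                             /*wd=*/0.0f, /*momentum=*/0.9f, /*dampening=*/0.0f,
                             /*lr=*/1e-3f, /*nesterov=*/0, /*first_run=*/1,
                             /*wd_after_momentum=*/0, /*scale=*/1.0f,
                             device, stream);
}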
......@@ -49,11 +49,11 @@ std::string to_string(const DType type) {
std::string to_string(const NVTEScalingMode &mode) {
switch (mode) {
case NVTE_DELAYED_TENSOR_SCALING:
return "Delayed Tensor Scaling";
return "NVTE_DELAYED_TENSOR_SCALING";
case NVTE_MXFP8_1D_SCALING:
return "MXFP8 1D Scaling";
return "NVTE_MXFP8_1D_SCALING";
case NVTE_INVALID_SCALING:
return "Invalid Scaling";
return "NVTE_INVALID_SCALING";
}
return "Invalid Scaling";
}
......
......@@ -97,6 +97,43 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor)
return makeTransformerEngineTensor(tensor.data_ptr(), shape, dtype);
}
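// Flattens nested at::Tensor lists into the NVTETensor** layout that the
// multi-tensor C API expects. The returned pointer vector only references
// storage owned by the other tuple members, so callers must keep the
// TensorWrappers and nested NVTETensor vectors alive until the kernel call
// returns (hence the unused structured bindings at the call sites).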
std::tuple<std::vector<transformer_engine::TensorWrapper>, std::vector<std::vector<NVTETensor>>,
std::vector<NVTETensor*>, size_t, size_t>
makeTransformerEngineTensorList(std::vector<std::vector<at::Tensor>> at_tensor_lists) {
size_t num_lists = at_tensor_lists.size();
NVTE_CHECK(num_lists > 0, "List of tensors is empty.");
size_t num_tensors = at_tensor_lists[0].size();
std::vector<std::vector<NVTETensor>> nvte_tensor_lists;
std::vector<NVTETensor*> nvte_tensor_list_ptrs;
std::vector<transformer_engine::TensorWrapper> tensorWrappers;
nvte_tensor_lists.reserve(num_lists);
nvte_tensor_list_ptrs.reserve(num_lists);
tensorWrappers.reserve(num_lists * num_tensors);
for (const auto& at_list : at_tensor_lists) {
NVTE_CHECK(at_list.size() == num_tensors,
"All tensor lists must contain the same number of tensors.");
std::vector<NVTETensor> te_list;
te_list.reserve(num_tensors);
for (const auto& at_tensor : at_list) {
tensorWrappers.push_back(makeTransformerEngineTensor(at_tensor));
te_list.push_back(tensorWrappers.back().data());
}
nvte_tensor_lists.push_back(std::move(te_list));
}
for (auto& te_tensor_list : nvte_tensor_lists) {
nvte_tensor_list_ptrs.push_back(te_tensor_list.data());
}
return std::make_tuple(std::move(tensorWrappers), std::move(nvte_tensor_lists),
std::move(nvte_tensor_list_ptrs), num_lists, num_tensors);
}
transformer_engine::TensorWrapper makeTransformerEngineTensor(
void* data_ptr, const std::vector<size_t>& shape, const transformer_engine::DType type,
void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, std::vector<size_t> scale_inv_shape,
......
......@@ -30,6 +30,7 @@
#include <transformer_engine/fused_attn.h>
#include <transformer_engine/fused_rope.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/multi_tensor.h>
#include <transformer_engine/normalization.h>
#include <transformer_engine/padding.h>
#include <transformer_engine/permutation.h>
......@@ -219,6 +220,8 @@ transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
inline at::ScalarType GetATenDType(transformer_engine::DType t) {
switch (t) {
case transformer_engine::DType::kInt16:
return torch::kInt16;
case transformer_engine::DType::kInt32:
return torch::kInt32;
case transformer_engine::DType::kInt64:
......@@ -256,6 +259,8 @@ inline transformer_engine::DType GetTransformerEngineDType(at::ScalarType t) {
return transformer_engine::DType::kByte;
case torch::kByte:
return transformer_engine::DType::kByte;
case torch::kInt16:
return transformer_engine::DType::kInt16;
case torch::kInt32:
return transformer_engine::DType::kInt32;
case torch::kInt64:
......@@ -293,6 +298,10 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr,
transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor);
std::tuple<std::vector<transformer_engine::TensorWrapper>, std::vector<std::vector<NVTETensor>>,
std::vector<NVTETensor*>, size_t, size_t>
makeTransformerEngineTensorList(std::vector<std::vector<at::Tensor>> at_tensor_lists);
TensorWrapper makeTransformerEngineTensor(py::handle tensor, py::handle quantizer);
transformer_engine::TensorWrapper makeTransformerEngineTensor(
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction,
const float weight_decay) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_adam_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists,
num_tensors, lr, beta1, beta2, epsilon, step, mode, bias_correction,
weight_decay, device_id, at::cuda::getCurrentCUDAStream());
}
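// The param-remainder variant pairs BF16 params with an int16 tensor holding
// the low 16 bits of each FP32 master weight, so the two halves together
// reconstruct the full FP32 value without a separate FP32 param copy. This
// reading follows the kInt16 plumbing added elsewhere in this change; the
// kernel body itself is not shown in this diff.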
void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode,
const int bias_correction, const float weight_decay) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_adam_param_remainder_cuda(
chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, lr, beta1,
beta2, epsilon, step, mode, bias_correction, weight_decay, device_id,
at::cuda::getCurrentCUDAStream());
}
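For orientation, a minimal C++ sketch (assumed names, not part of this diff) of calling the plain multi_tensor_adam_cuda binding defined at the top of this file. The {grads, params, exp_avg, exp_avg_sq} list order and the mode semantics (0 = L2 weight decay, 1 = decoupled/AdamW) follow the apex fused-Adam convention and are assumptions here.

#include <ATen/ATen.h>
#include "extensions.h"  // assumed to declare multi_tensor_adam_cuda

// Hypothetical helper: one fused Adam step over parallel tensor lists.
void adam_step_sketch(const std::vector<at::Tensor>& grads,
                      const std::vector<at::Tensor>& params,
                      const std::vector<at::Tensor>& exp_avgs,
                      const std::vector<at::Tensor>& exp_avg_sqs, int step) {
  // Zero noop flag means "run"; the kernel skips work when it is nonzero.
  auto noop = at::zeros({1}, params[0].options().dtype(at::kInt));
  // Assumed apex-style ordering: {grads, params, exp_avg, exp_avg_sq}.
  std::vector<std::vector<at::Tensor>> lists = {grads, params, exp_avgs,
                                                exp_avg_sqs};
  multi_tensor_adam_cuda(/*chunk_size=*/65536, noop, lists, /*lr=*/1e-3f,
                         /*beta1=*/0.9f, /*beta2=*/0.999f, /*epsilon=*/1e-8f,
                         step, /*mode=*/0, /*bias_correction=*/1,
                         /*weight_decay=*/0.0f);
}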
void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction,
const float weight_decay, DType fp8_dtype) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_adam_fp8_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(),
num_lists, num_tensors, lr, beta1, beta2, epsilon, step, mode,
bias_correction, weight_decay, static_cast<NVTEDType>(fp8_dtype),
device_id, at::cuda::getCurrentCUDAStream());
}
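A note on the two capturable variants that follow: lr, step, and inv_scale are passed as device tensors rather than host scalars, which is the standard apex-style design for making the optimizer step recordable into a CUDA graph; the values can then change between replays without re-capture. The _master variant presumably carries FP32 master weights in an extra tensor list (an inference from its name; the kernel body is not in this diff).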
void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor lr, const float beta1, const float beta2,
const float epsilon, at::Tensor step, const int mode,
const int bias_correction, const float weight_decay,
at::Tensor inv_scale) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
auto lr_cu = makeTransformerEngineTensor(lr);
auto step_cu = makeTransformerEngineTensor(step);
auto inv_scale_cu = makeTransformerEngineTensor(inv_scale);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_adam_capturable_cuda(
chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors,
lr_cu.data(), beta1, beta2, epsilon, step_cu.data(), mode, bias_correction, weight_decay,
inv_scale_cu.data(), device_id, at::cuda::getCurrentCUDAStream());
}
void multi_tensor_adam_capturable_master_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor lr, const float beta1, const float beta2,
const float epsilon, at::Tensor step, const int mode,
const int bias_correction, const float weight_decay,
at::Tensor inv_scale) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
auto lr_cu = makeTransformerEngineTensor(lr);
auto step_cu = makeTransformerEngineTensor(step);
auto inv_scale_cu = makeTransformerEngineTensor(inv_scale);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_adam_capturable_master_cuda(
chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors,
lr_cu.data(), beta1, beta2, epsilon, step_cu.data(), mode, bias_correction, weight_decay,
inv_scale_cu.data(), device_id, at::cuda::getCurrentCUDAStream());
}
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
void multi_tensor_compute_scale_and_scale_inv_cuda(
int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
float max_fp8, bool force_pow_2_scales, float epsilon) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_compute_scale_and_scale_inv_cuda(
chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, max_fp8,
force_pow_2_scales, epsilon, device_id, at::cuda::getCurrentCUDAStream());
}
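The kernel body is not shown in this diff, so as orientation, here is a hedged sketch of the per-tensor math the parameter names (max_fp8, force_pow_2_scales, epsilon) suggest: derive a scale from amax so the largest value fits the FP8 range, optionally rounding down to a power of two; the kernel presumably also writes scale_inv = 1/scale.

#include <algorithm>
#include <cmath>

// Assumed per-tensor scale recipe; an illustration, not the kernel's code.
float compute_scale_sketch(float amax, float max_fp8, float epsilon,
                           bool force_pow_2_scales) {
  const float clamped_amax = std::max(amax, epsilon);  // avoid divide-by-zero
  float scale = max_fp8 / clamped_amax;  // map the largest value to max_fp8
  if (force_pow_2_scales) {
    scale = std::exp2(std::floor(std::log2(scale)));  // round down to 2^k
  }
  return scale;
}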
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
at::optional<bool> per_tensor_python) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
bool per_tensor = per_tensor_python.value_or(false);
auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
auto output = at::zeros({320}, float_options);
at::Tensor output_per_tensor;
at::Tensor ret_per_tensor;
auto ret = at::empty({1}, output.options());
int ntensors = tensor_lists[0].size();
int max_chunks_per_tensor = -1;
if (per_tensor) {
for (int t = 0; t < ntensors; t++) {
int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
if (max_chunks_this_tensor > max_chunks_per_tensor)
max_chunks_per_tensor = max_chunks_this_tensor;
}
output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor}, float_options);
ret_per_tensor = at::empty({ntensors}, float_options);
} else {
output_per_tensor = at::empty({0}, float_options);
ret_per_tensor = at::empty({0}, float_options);
}
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
auto output_cu = makeTransformerEngineTensor(output);
auto output_per_tensor_cu = makeTransformerEngineTensor(output_per_tensor);
auto ret_cu = makeTransformerEngineTensor(ret);
auto ret_per_tensor_cu = makeTransformerEngineTensor(ret_per_tensor);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_l2norm_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists,
num_tensors, output_cu.data(), output_per_tensor_cu.data(),
ret_cu.data(), ret_per_tensor_cu.data(), per_tensor,
max_chunks_per_tensor, device_id, at::cuda::getCurrentCUDAStream());
return {ret, ret_per_tensor};
}
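A short usage sketch (assumed setup, not part of this commit) for the binding above. With per_tensor=true the second return value holds one norm per input tensor; the fixed 320-element output buffer appears to be the kernel's upper bound on per-block partial sums. Since the global L2 norm equals the norm of the per-tensor norms, an ATen reference check is straightforward:

#include <ATen/ATen.h>
#include "extensions.h"  // assumed to declare multi_tensor_l2norm_cuda

void l2norm_sketch() {
  std::vector<at::Tensor> ts = {at::randn({1024}, at::kCUDA),
                                at::randn({513}, at::kCUDA)};
  auto noop = at::zeros({1}, ts[0].options().dtype(at::kInt));
  auto [norm, per_tensor_norms] =
      multi_tensor_l2norm_cuda(/*chunk_size=*/65536, noop, {ts},
                               /*per_tensor_python=*/true);
  TORCH_CHECK(per_tensor_norms.numel() == 2);  // one norm per input tensor
  // Global L2 norm == norm of the per-tensor norms, up to fp32 rounding.
  auto ref = at::stack({ts[0].norm(), ts[1].norm()}).norm();
  TORCH_CHECK(at::allclose(norm, ref), "fused norm mismatch");
}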
std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(
int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor inv_scale, at::optional<bool> per_tensor_python) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
bool per_tensor = per_tensor_python.value_or(false);
auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
auto output = at::zeros({320}, float_options);
at::Tensor output_per_tensor;
at::Tensor ret_per_tensor;
int ntensors = tensor_lists[0].size();
int max_chunks_per_tensor = -1;
// Create output tensors for the multi-tensor unscale L2 norm kernel.
if (per_tensor) {
for (int t = 0; t < ntensors; t++) {
int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
if (max_chunks_this_tensor > max_chunks_per_tensor)
max_chunks_per_tensor = max_chunks_this_tensor;
}
output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor}, float_options);
ret_per_tensor = at::empty({ntensors}, float_options);
} else {
output_per_tensor = at::empty({0}, float_options);
ret_per_tensor = at::empty({0}, float_options);
}
auto ret = at::empty({1}, output.options());
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
auto output_cu = makeTransformerEngineTensor(output);
auto output_per_tensor_cu = makeTransformerEngineTensor(output_per_tensor);
auto ret_cu = makeTransformerEngineTensor(ret);
auto ret_per_tensor_cu = makeTransformerEngineTensor(ret_per_tensor);
auto inv_scale_cu = makeTransformerEngineTensor(inv_scale);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_unscale_l2norm_cuda(
chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors,
output_cu.data(), output_per_tensor_cu.data(), ret_cu.data(), ret_per_tensor_cu.data(),
inv_scale_cu.data(), per_tensor, max_chunks_per_tensor, device_id,
at::cuda::getCurrentCUDAStream());
return {ret, ret_per_tensor};
}
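The unscale variant above folds the multiplication by inv_scale into the norm reduction itself. A plausible reading (the kernel body is not shown here): this lets a mixed-precision gradient scaler check gradient norms for overflow without first materializing the unscaled gradients.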