Unverified Commit c972f5a7 authored by Kirthi Shankar Sivamani, committed by GitHub

[C][PyTorch] Move multi tensors kernels from PyTorch extensions to core (#1744)



* Move multi tensors kernels from PyTorch extensions to core
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add int16 type to core (for storing fp32 param remainders)
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix core build
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Same fix to scale
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix perf, memory, vars
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Re-add device guard for multi-device
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix junk output dtype for non-per tensor
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes for test and upgrade mcore version
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix core tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent e17fab14
@@ -17,7 +17,7 @@ fi
 # Download Megatron-LM if needed
 if [ ! -d "${MCORE_PATH}" ]; then
     pushd $(dirname ${MCORE_PATH})
-    git clone -b core_r0.9.0 https://github.com/NVIDIA/Megatron-LM.git Megatron-LM
+    git clone -b core_r0.12.0 https://github.com/NVIDIA/Megatron-LM.git Megatron-LM
     popd
 fi
...
@@ -46,6 +46,7 @@ struct BytesToType<8> {
 };
 using byte = uint8_t;
+using int16 = int16_t;
 using int32 = int32_t;
 using int64 = int64_t;
 using fp32 = float;
@@ -58,6 +59,7 @@ using fp8e8m0 = uint8_t;
 template <typename T>
 struct TypeInfo{
   using types = std::tuple<byte,
+                           int16,
                            int32,
                            int64,
                            fp32,
...
@@ -159,7 +159,7 @@ def test_multi_tensor_l2norm(input_size_pair, applier, repeat, in_type, per_tens
         normab = torch.cat((a.norm().view(1), b.norm().view(1)))
         norm_per_tensor = norm_per_tensor.view(-1, 2)
     else:
-        norm, _ = applier(tex.multi_tensor_l2norm, overflow_buf, [in_list], True)
+        norm, _ = applier(tex.multi_tensor_l2norm, overflow_buf, [in_list], False)
     reference = torch.full(
         [(sizea + sizeb) * repeat], val, dtype=torch.float32, device=device
...
@@ -53,6 +53,11 @@ list(APPEND transformer_engine_SOURCES
      cudnn_utils.cpp
      transformer_engine.cpp
      common.cu
+     multi_tensor/adam.cu
+     multi_tensor/compute_scale.cu
+     multi_tensor/l2norm.cu
+     multi_tensor/scale.cu
+     multi_tensor/sgd.cu
      transpose/cast_transpose.cu
      transpose/transpose.cu
      transpose/cast_transpose_fusion.cu
@@ -163,6 +168,11 @@ target_include_directories(transformer_engine PRIVATE
 set_source_files_properties(fused_softmax/scaled_masked_softmax.cu
                             fused_softmax/scaled_upper_triang_masked_softmax.cu
                             fused_softmax/scaled_aligned_causal_masked_softmax.cu
+                            multi_tensor/adam.cu
+                            multi_tensor/compute_scale.cu
+                            multi_tensor/l2norm.cu
+                            multi_tensor/scale.cu
+                            multi_tensor/sgd.cu
                             PROPERTIES
                             COMPILE_OPTIONS "--use_fast_math")
 option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF)
...
@@ -193,4 +193,16 @@ bool is_supported_by_CC_100() {
   return deviceComputeCapability >= 100;
 }
+
+std::vector<std::vector<Tensor *>> convert_tensor_array(NVTETensor **nvte_tensors,
+                                                        size_t outer_size, size_t inner_size) {
+  std::vector<std::vector<Tensor *>> ret;
+  for (size_t i = 0; i < outer_size; ++i) {
+    ret.emplace_back();
+    for (size_t j = 0; j < inner_size; ++j) {
+      ret.back().push_back(reinterpret_cast<Tensor *>(nvte_tensors[i][j]));
+    }
+  }
+  return ret;
+}
 }  // namespace transformer_engine
@@ -238,6 +238,7 @@ constexpr T DIVUP(const T &x, const T &y) {
 }
 using byte = uint8_t;
+using int16 = int16_t;
 using int32 = int32_t;
 using int64 = int64_t;
 using fp32 = float;
@@ -260,6 +261,7 @@ constexpr inline const char *type_name() noexcept;
     return #T;  \
   }
 TRANSFORMER_ENGINE_TYPE_NAME(uint8_t)
+TRANSFORMER_ENGINE_TYPE_NAME(int16_t)
 TRANSFORMER_ENGINE_TYPE_NAME(int32_t)
 TRANSFORMER_ENGINE_TYPE_NAME(int64_t)
 TRANSFORMER_ENGINE_TYPE_NAME(float)
@@ -306,7 +308,7 @@ struct TypeExtrema {
 template <typename T>
 struct TypeInfo {
-  using types = std::tuple<byte, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2>;
+  using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2>;
   template <typename U, DType current>
   struct Helper {
@@ -343,6 +345,10 @@ struct TypeInfo {
       using type = unsigned char;  \
       { __VA_ARGS__ }              \
     } break;                       \
+    case DType::kInt16: {          \
+      using type = int16_t;        \
+      { __VA_ARGS__ }              \
+    } break;                       \
     case DType::kInt32: {          \
       using type = int32_t;        \
       { __VA_ARGS__ }              \
@@ -576,6 +582,9 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
 bool is_supported_by_CC_100();
+
+std::vector<std::vector<Tensor *>> convert_tensor_array(NVTETensor **nvte_tensors,
+                                                        size_t outer_size, size_t inner_size);
+
 }  // namespace transformer_engine
 #endif  // TRANSFORMER_ENGINE_COMMON_COMMON_H_
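The int16 element type added above supports the Adam param-remainder path: an fp32 master weight is kept as a BF16 parameter (its high 16 bits) plus a 16-bit remainder (its low 16 bits). Below is a minimal sketch of that split and reconstruction, as an illustration of the technique rather than the kernel's actual code; rounding and NaN handling are assumptions.

```cpp
#include <cstdint>
#include <cstring>

// Sketch: rebuild an fp32 master value from a bf16 parameter (top 16 bits)
// and an int16 remainder (bottom 16 bits), and split it back after an update.
float combine_param_and_remainder(uint16_t bf16_bits, uint16_t remainder_bits) {
  uint32_t bits = (static_cast<uint32_t>(bf16_bits) << 16) | remainder_bits;
  float master;
  std::memcpy(&master, &bits, sizeof(master));  // bit-exact fp32 reconstruction
  return master;
}

void split_master(float master, uint16_t *bf16_bits, uint16_t *remainder_bits) {
  uint32_t bits;
  std::memcpy(&bits, &master, sizeof(bits));
  *bf16_bits = static_cast<uint16_t>(bits >> 16);           // stored as the BF16 param
  *remainder_bits = static_cast<uint16_t>(bits & 0xFFFFu);  // stored in the int16 tensor
}
```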
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file multi_tensor.h
* \brief Functions handling multi tensor kernels.
*/
#ifndef TRANSFORMER_ENGINE_MULTI_TENSOR_H_
#define TRANSFORMER_ENGINE_MULTI_TENSOR_H_
#include "transformer_engine.h"
#ifdef __cplusplus
extern "C" {
#endif
void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
NVTETensor output, NVTETensor output_per_tensor, NVTETensor ret,
NVTETensor ret_per_tensor, int per_tensor,
int max_chunks_per_tensor, const int device_id,
cudaStream_t stream);
void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, NVTETensor output,
NVTETensor output_per_tensor, NVTETensor ret,
NVTETensor ret_per_tensor, NVTETensor inv_scale,
int per_tensor, int max_chunks_per_tensor,
const int device_id, cudaStream_t stream);
void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode,
const int bias_correction, const float weight_decay,
const int device_id, cudaStream_t stream);
void nvte_multi_tensor_adam_param_remainder_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode, const int bias_correction,
const float weight_decay, const int device_id, cudaStream_t stream);
void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, const float lr,
const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction,
const float weight_decay, const NVTEDType fp8_dtype,
const int device_id, cudaStream_t stream);
void nvte_multi_tensor_adam_capturable_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2,
const float epsilon, NVTETensor step, const int mode, const int bias_correction,
const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream);
void nvte_multi_tensor_adam_capturable_master_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2,
const float epsilon, NVTETensor step, const int mode, const int bias_correction,
const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream);
void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
float wd, float momentum, float dampening, float lr, int nesterov,
int first_run, int wd_after_momentum, float scale,
const int device_id, cudaStream_t stream);
void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
float scale, const int device_id, cudaStream_t stream);
void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, float max_fp8, int force_pow_2_scales, float epsilon,
const int device_id, cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
#endif // TRANSFORMER_ENGINE_MULTI_TENSOR_H_
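The header above is the full C surface for the moved kernels. A hedged usage sketch of one entry point follows: the four-list layout matches the "g, p, m, v" case checked in adam.cu, while the chunk size and hyperparameter values are illustrative, and creating the NVTETensor handles held in the vectors is left to the caller (it is not part of this header).

```cpp
#include <cuda_runtime.h>
#include <transformer_engine/multi_tensor.h>

#include <vector>

// Sketch: one fused Adam step over n parameter tensors through the new C API.
// All NVTETensor handles are assumed to have been created/wrapped elsewhere.
void fused_adam_step(std::vector<NVTETensor> &grads, std::vector<NVTETensor> &params,
                     std::vector<NVTETensor> &exp_avg, std::vector<NVTETensor> &exp_avg_sq,
                     NVTETensor noop_flag, float lr, int step, int device_id,
                     cudaStream_t stream) {
  // Four lists in the "g, p, m, v" order; each list holds the same number of tensors.
  NVTETensor *tensor_lists[4] = {grads.data(), params.data(), exp_avg.data(), exp_avg_sq.data()};
  nvte_multi_tensor_adam_cuda(/*chunk_size=*/65536, noop_flag, tensor_lists,
                              /*num_tensor_lists=*/4, /*num_tensors_per_list=*/grads.size(),
                              lr, /*beta1=*/0.9f, /*beta2=*/0.999f, /*epsilon=*/1e-8f, step,
                              /*mode=*/0, /*bias_correction=*/1, /*weight_decay=*/0.0f,
                              device_id, stream);
}
```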
@@ -23,14 +23,15 @@ extern "C" {
  */
 enum NVTEDType {
   kNVTEByte = 0,       /*!< Byte */
-  kNVTEInt32 = 1,      /*!< 32-bit integer */
-  kNVTEInt64 = 2,      /*!< 64-bit integer */
-  kNVTEFloat32 = 3,    /*!< 32-bit float */
-  kNVTEFloat16 = 4,    /*!< 16-bit float (E5M10) */
-  kNVTEBFloat16 = 5,   /*!< 16-bit bfloat (E8M7) */
-  kNVTEFloat8E4M3 = 6, /*!< 8-bit float (E4M3) */
-  kNVTEFloat8E5M2 = 7, /*!< 8-bit float (E5M2) */
-  kNVTEFloat8E8M0 = 8, /*!< 8-bit float (E8M0) */
+  kNVTEInt16 = 1,      /*!< 16-bit integer */
+  kNVTEInt32 = 2,      /*!< 32-bit integer */
+  kNVTEInt64 = 3,      /*!< 64-bit integer */
+  kNVTEFloat32 = 4,    /*!< 32-bit float */
+  kNVTEFloat16 = 5,    /*!< 16-bit float (E5M10) */
+  kNVTEBFloat16 = 6,   /*!< 16-bit bfloat (E8M7) */
+  kNVTEFloat8E4M3 = 7, /*!< 8-bit float (E4M3) */
+  kNVTEFloat8E5M2 = 8, /*!< 8-bit float (E5M2) */
+  kNVTEFloat8E8M0 = 9, /*!< 8-bit float (E8M0) */
   kNVTENumTypes        /*!< Number of supported types */
 };
@@ -373,14 +374,15 @@ namespace transformer_engine {
  */
 enum class DType {
   kByte = 0,
-  kInt32 = 1,
-  kInt64 = 2,
-  kFloat32 = 3,
-  kFloat16 = 4,
-  kBFloat16 = 5,
-  kFloat8E4M3 = 6,
-  kFloat8E5M2 = 7,
-  kFloat8E8M0 = 8,
+  kInt16 = 1,
+  kInt32 = 2,
+  kInt64 = 3,
+  kFloat32 = 4,
+  kFloat16 = 5,
+  kBFloat16 = 6,
+  kFloat8E4M3 = 7,
+  kFloat8E5M2 = 8,
+  kFloat8E8M0 = 9,
   kNumTypes
 };
...
@@ -4,19 +4,16 @@
  * See LICENSE for license information.
  ************************************************************************/

-#include <ATen/ATen.h>
-#include <ATen/AccumulateType.h>
-#include <ATen/cuda/CUDAContext.h>
-#include <ATen/cuda/Exceptions.h>
-#include <cuda_fp8.h>
-// Another possibility:
-// #include <torch/all.h>
 #include <assert.h>
+#include <cuda_fp8.h>
+#include <transformer_engine/multi_tensor.h>
+#include <transformer_engine/transformer_engine.h>

-#include "common/utils.cuh"
+#include "../utils.cuh"
 #include "multi_tensor_apply.cuh"
-#include "type_shim.h"
+
+namespace transformer_engine {
+namespace multi_tensor_adam {

 #define BLOCK_SIZE 512
 #define ILP 4
@@ -30,7 +27,6 @@ typedef enum {
 using MATH_T = float;
 using fp8e4m3 = __nv_fp8_e4m3;
 using fp8e5m2 = __nv_fp8_e5m2;
-using transformer_engine::DType;

 template <typename T>
 struct is_fp8 : std::false_type {};
@@ -576,12 +572,13 @@ struct AdamCapturableMasterFunctor {
   }
 };

-void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
-                            std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
+void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
+                            std::vector<std::vector<Tensor *>> tensor_lists, const float lr,
                             const float beta1, const float beta2, const float epsilon,
                             const int step, const int mode, const int bias_correction,
-                            const float weight_decay) {
-  using namespace at;
+                            const float weight_decay, const int device_id, cudaStream_t stream) {
+  const size_t num_tensor_lists = tensor_lists.size();
+  const size_t num_tensors_per_list = tensor_lists[0].size();

   // Handle bias correction mode
   float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
@@ -592,10 +589,10 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
   size_t max_size = 0;
   bool requires_64bit_indexing = false;
-  for (auto it = tensor_lists.begin(); it != tensor_lists.end(); it++) {
-    for (auto it2 = it->begin(); it2 != it->end(); it2++) {
-      if (it2->numel() > max_size) {
-        max_size = it2->numel();
+  for (size_t i = 0; i < num_tensor_lists; i++) {
+    for (size_t j = 0; j < num_tensors_per_list; j++) {
+      if (tensor_lists[i][j]->numel() > max_size) {
+        max_size = tensor_lists[i][j]->numel();
         if (max_size >= INT_MAX) {
           requires_64bit_indexing = true;
           break;
@@ -607,69 +604,70 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
     }
   }

-  const auto g_in_type = tensor_lists[0][0].scalar_type();
-  const auto p_in_type = tensor_lists[1][0].scalar_type();
-  auto tl_size = tensor_lists.size();
+  const auto g_in_type_te = tensor_lists[0][0]->dtype();
+  const auto p_in_type_te = tensor_lists[1][0]->dtype();

   // case 4: g, p, m, v
   // case 5: g, p, m, v, p_master
-  TORCH_CHECK(tl_size == 4 || tl_size == 5, "tensor list must contain 4 or 5");
+  NVTE_CHECK(num_tensor_lists == 4 || num_tensor_lists == 5, "tensor list must contain 4 or 5");

   if (requires_64bit_indexing) {
-    if (tl_size == 4) {
+    if (num_tensor_lists == 4) {
       // Assume single type across p,g,m1,m2 now
-      DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
-          p_in_type, 0, "adam",
-          DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
-              g_in_type, 1, "adam",
+      TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+          p_in_type_te, p_in_type,
+          TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+              g_in_type_te, g_in_type,
               multi_tensor_apply<4>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag,
                                     tensor_lists,
-                                    AdamFunctor<scalar_t_0, scalar_t_1, float, int64_t>(), beta1,
-                                    beta2, bias_correction1, bias_correction2, epsilon, lr,
-                                    (adamMode_t)mode, weight_decay);));
+                                    AdamFunctor<p_in_type, g_in_type, float, int64_t>(), device_id,
+                                    stream, beta1, beta2, bias_correction1, bias_correction2,
+                                    epsilon, lr, (adamMode_t)mode, weight_decay);));
     } else {
       // g, p, m, v, p_master
-      DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
-          p_in_type, 0, "adam",
-          DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
-              g_in_type, 1, "adam",
-              multi_tensor_apply<5>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag,
-                                    tensor_lists,
-                                    AdamFunctorMaster<scalar_t_0, scalar_t_1, float, int64_t>(),
-                                    beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
-                                    (adamMode_t)mode, weight_decay);));
+      TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+          p_in_type_te, p_in_type,
+          TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+              g_in_type_te, g_in_type,
+              multi_tensor_apply<5>(
+                  (int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,
+                  AdamFunctorMaster<p_in_type, g_in_type, float, int64_t>(), device_id, stream,
+                  beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode,
+                  weight_decay);));
     }
   } else {
-    if (tl_size == 4) {
+    if (num_tensor_lists == 4) {
       // Assume single type across p,g,m1,m2 now
-      DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
-          p_in_type, 0, "adam",
-          DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
-              g_in_type, 1, "adam",
+      TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+          p_in_type_te, p_in_type,
+          TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+              g_in_type_te, g_in_type,
               multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
-                                    AdamFunctor<scalar_t_0, scalar_t_1, float, int32_t>(), beta1,
-                                    beta2, bias_correction1, bias_correction2, epsilon, lr,
-                                    (adamMode_t)mode, weight_decay);));
+                                    AdamFunctor<p_in_type, g_in_type, float, int32_t>(), device_id,
+                                    stream, beta1, beta2, bias_correction1, bias_correction2,
+                                    epsilon, lr, (adamMode_t)mode, weight_decay);));
     } else {
-      DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
-          p_in_type, 0, "adam",
-          DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
-              g_in_type, 1, "adam",
+      TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+          p_in_type_te, p_in_type,
+          TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+              g_in_type_te, g_in_type,
               multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
-                                    AdamFunctorMaster<scalar_t_0, scalar_t_1, float, int32_t>(),
-                                    beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
-                                    (adamMode_t)mode, weight_decay);));
+                                    AdamFunctorMaster<p_in_type, g_in_type, float, int32_t>(),
+                                    device_id, stream, beta1, beta2, bias_correction1,
+                                    bias_correction2, epsilon, lr, (adamMode_t)mode,
+                                    weight_decay);));
     }
   }

-  AT_CUDA_CHECK(cudaGetLastError());
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

-void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag,
-                                            std::vector<std::vector<at::Tensor>> tensor_lists,
+void multi_tensor_adam_param_remainder_cuda(int chunk_size, Tensor noop_flag,
+                                            std::vector<std::vector<Tensor *>> tensor_lists,
                                             const float lr, const float beta1, const float beta2,
                                             const float epsilon, const int step, const int mode,
-                                            const int bias_correction, const float weight_decay) {
-  using namespace at;
+                                            const int bias_correction, const float weight_decay,
+                                            const int device_id, cudaStream_t stream) {
+  const size_t num_tensor_lists = tensor_lists.size();

   // Handle bias correction mode
   float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
@@ -678,34 +676,34 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag
     bias_correction2 = 1 - std::pow(beta2, step);
   }

-  const auto g_in_type = tensor_lists[0][0].scalar_type();
-  const auto p_in_type = tensor_lists[1][0].scalar_type();
-  auto tl_size = tensor_lists.size();
+  const auto g_in_type_te = tensor_lists[0][0]->dtype();
+  const auto p_in_type_te = tensor_lists[1][0]->dtype();

   // case 5: g, p, m, v, p_master
-  TORCH_CHECK(tl_size == 5, "tensor list must contain 5");
-  TORCH_CHECK(p_in_type == at::ScalarType::BFloat16,
-              "Adam with BF16 param remainders requires BF16 params");
+  NVTE_CHECK(num_tensor_lists == 5, "tensor list must contain 5");
+  NVTE_CHECK(p_in_type_te == DType::kBFloat16,
+             "Adam with BF16 param remainders requires BF16 params");

   // g, p, m, v, p_master
-  DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
-      p_in_type, 0, "adam",
-      DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
-          g_in_type, 1, "adam",
-          multi_tensor_apply<5>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,
-                                AdamFunctorMasterParamRemainder<scalar_t_1, float, int64_t>(),
-                                beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
-                                (adamMode_t)mode, weight_decay);));
-
-  AT_CUDA_CHECK(cudaGetLastError());
+  TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+      g_in_type_te, g_in_type,
+      multi_tensor_apply<5>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,
+                            AdamFunctorMasterParamRemainder<g_in_type, float, int64_t>(), device_id,
+                            stream, beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
+                            (adamMode_t)mode, weight_decay););
+
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

-void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
-                                std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
+void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag,
+                                std::vector<std::vector<Tensor *>> tensor_lists, const float lr,
                                 const float beta1, const float beta2, const float epsilon,
                                 const int step, const int mode, const int bias_correction,
-                                const float weight_decay, DType fp8_dtype) {
-  using namespace at;
+                                const float weight_decay, const DType fp8_dtype,
+                                const int device_id, cudaStream_t stream) {
+  const size_t num_tensor_lists = tensor_lists.size();
+  const size_t num_tensors_per_list = tensor_lists[0].size();

   // Handle bias correction mode
   float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
@@ -716,10 +714,10 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
   size_t max_size = 0;
   bool requires_64bit_indexing = false;
-  for (auto it = tensor_lists.begin(); it != tensor_lists.end(); it++) {
-    for (auto it2 = it->begin(); it2 != it->end(); it2++) {
-      if (it2->numel() > max_size) {
-        max_size = it2->numel();
+  for (size_t i = 0; i < num_tensor_lists; i++) {
+    for (size_t j = 0; j < num_tensors_per_list; j++) {
+      if (tensor_lists[i][j]->numel() > max_size) {
+        max_size = tensor_lists[i][j]->numel();
         if (max_size >= INT_MAX) {
           requires_64bit_indexing = true;
           break;
@@ -731,66 +729,147 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
     }
   }

-  const auto g_in_type = tensor_lists[0][0].scalar_type();
-  auto tl_size = tensor_lists.size();
+  const auto g_in_type_te = tensor_lists[0][0]->dtype();

   // case 8: g, p_fp8, m, v, p_master, scale, amax, scale_inv
-  TORCH_CHECK(tl_size == 8, "tensor list must contain 8 tensors");
+  NVTE_CHECK(num_tensor_lists == 8, "tensor list must contain 8 tensors");

   if (requires_64bit_indexing) {
     TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
         fp8_dtype, FP8_T,
-        DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
-            g_in_type, 0, "adam",
+        TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+            g_in_type_te, g_in_type,
             multi_tensor_apply<5, true>(
                 (int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,
-                AdamFunctorMaster<FP8_T, scalar_t_0, float, int64_t>(), beta1, beta2,
-                bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, weight_decay);));
+                AdamFunctorMaster<FP8_T, g_in_type, float, int64_t>(), device_id, stream, beta1,
+                beta2, bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode,
+                weight_decay);));
   } else {
     TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
         fp8_dtype, FP8_T,
-        DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
-            g_in_type, 0, "adam",
+        TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+            g_in_type_te, g_in_type,
             multi_tensor_apply<5, true>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
-                                        AdamFunctorMaster<FP8_T, scalar_t_0, float, int32_t>(),
-                                        beta1, beta2, bias_correction1, bias_correction2, epsilon,
-                                        lr, (adamMode_t)mode, weight_decay);));
+                                        AdamFunctorMaster<FP8_T, g_in_type, float, int32_t>(),
+                                        device_id, stream, beta1, beta2, bias_correction1,
+                                        bias_correction2, epsilon, lr, (adamMode_t)mode,
+                                        weight_decay);));
   }

-  AT_CUDA_CHECK(cudaGetLastError());
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

-void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag,
-                                       std::vector<std::vector<at::Tensor>> tensor_lists,
-                                       at::Tensor lr, const float beta1, const float beta2,
-                                       const float epsilon, at::Tensor step, const int mode,
-                                       const int bias_correction, const float weight_decay,
-                                       at::Tensor inv_scale) {
-  using namespace at;
-
-  DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
-      tensor_lists[0][0].scalar_type(), 0, "adam",
-      multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
-                            AdamCapturableFunctor<scalar_t_0, float>(), beta1, beta2,
-                            step.data_ptr<int>(), bias_correction, epsilon, lr.data_ptr<float>(),
-                            (adamMode_t)mode, weight_decay, inv_scale.data_ptr<float>());)
-
-  AT_CUDA_CHECK(cudaGetLastError());
-}
+void multi_tensor_adam_capturable_cuda(int chunk_size, Tensor noop_flag,
+                                       std::vector<std::vector<Tensor *>> tensor_lists, Tensor lr,
+                                       const float beta1, const float beta2, const float epsilon,
+                                       Tensor step, const int mode, const int bias_correction,
+                                       const float weight_decay, Tensor inv_scale,
+                                       const int device_id, cudaStream_t stream) {
+  TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+      tensor_lists[0][0]->dtype(), dtype,
+      multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
+                            AdamCapturableFunctor<dtype, float>(), device_id, stream, beta1, beta2,
+                            reinterpret_cast<int *>(step.data.dptr), bias_correction, epsilon,
+                            reinterpret_cast<float *>(lr.data.dptr), (adamMode_t)mode, weight_decay,
+                            reinterpret_cast<float *>(inv_scale.data.dptr));)
+
+  NVTE_CHECK_CUDA(cudaGetLastError());
+}

-void multi_tensor_adam_capturable_master_cuda(int chunk_size, at::Tensor noop_flag,
-                                              std::vector<std::vector<at::Tensor>> tensor_lists,
-                                              at::Tensor lr, const float beta1, const float beta2,
-                                              const float epsilon, at::Tensor step, const int mode,
-                                              const int bias_correction, const float weight_decay,
-                                              at::Tensor inv_scale) {
-  using namespace at;
-
-  DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
-      tensor_lists[0][0].scalar_type(), 0, "adam",
-      multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
-                            AdamCapturableMasterFunctor<scalar_t_0, float>(), beta1, beta2,
-                            step.data_ptr<int>(), bias_correction, epsilon, lr.data_ptr<float>(),
-                            (adamMode_t)mode, weight_decay, inv_scale.data_ptr<float>());)
-
-  AT_CUDA_CHECK(cudaGetLastError());
-}
+void multi_tensor_adam_capturable_master_cuda(int chunk_size, Tensor noop_flag,
+                                              std::vector<std::vector<Tensor *>> tensor_lists,
+                                              Tensor lr, const float beta1, const float beta2,
+                                              const float epsilon, Tensor step, const int mode,
+                                              const int bias_correction, const float weight_decay,
+                                              Tensor inv_scale, const int device_id,
+                                              cudaStream_t stream) {
+  TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+      tensor_lists[0][0]->dtype(), dtype,
+      multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
+                            AdamCapturableMasterFunctor<dtype, float>(), device_id, stream, beta1,
+                            beta2, reinterpret_cast<int *>(step.data.dptr), bias_correction,
+                            epsilon, reinterpret_cast<float *>(lr.data.dptr), (adamMode_t)mode,
+                            weight_decay, reinterpret_cast<float *>(inv_scale.data.dptr));)
+
+  NVTE_CHECK_CUDA(cudaGetLastError());
+}
+
+}  // namespace multi_tensor_adam
+}  // namespace transformer_engine
+
void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode,
const int bias_correction, const float weight_decay,
const int device_id, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_adam_cuda);
using namespace transformer_engine;
multi_tensor_adam::multi_tensor_adam_cuda(
chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2,
epsilon, step, mode, bias_correction, weight_decay, device_id, stream);
}
void nvte_multi_tensor_adam_param_remainder_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode, const int bias_correction,
const float weight_decay, const int device_id, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_adam_param_remainder_cuda);
using namespace transformer_engine;
multi_tensor_adam::multi_tensor_adam_param_remainder_cuda(
chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2,
epsilon, step, mode, bias_correction, weight_decay, device_id, stream);
}
void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, const float lr,
const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction,
const float weight_decay, const NVTEDType fp8_dtype,
const int device_id, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_adam_fp8_cuda);
using namespace transformer_engine;
multi_tensor_adam::multi_tensor_adam_fp8_cuda(
chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2,
epsilon, step, mode, bias_correction, weight_decay, static_cast<DType>(fp8_dtype), device_id,
stream);
}
void nvte_multi_tensor_adam_capturable_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2,
const float epsilon, NVTETensor step, const int mode, const int bias_correction,
const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_adam_capturable_cuda);
using namespace transformer_engine;
multi_tensor_adam::multi_tensor_adam_capturable_cuda(
chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
*reinterpret_cast<Tensor *>(lr), beta1, beta2, epsilon, *reinterpret_cast<Tensor *>(step),
mode, bias_correction, weight_decay, *reinterpret_cast<Tensor *>(inv_scale), device_id,
stream);
}
void nvte_multi_tensor_adam_capturable_master_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2,
const float epsilon, NVTETensor step, const int mode, const int bias_correction,
const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_adam_capturable_master_cuda);
using namespace transformer_engine;
multi_tensor_adam::multi_tensor_adam_capturable_master_cuda(
chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
*reinterpret_cast<Tensor *>(lr), beta1, beta2, epsilon, *reinterpret_cast<Tensor *>(step),
mode, bias_correction, weight_decay, *reinterpret_cast<Tensor *>(inv_scale), device_id,
stream);
}
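For orientation, the functors dispatched in this file implement the standard Adam/AdamW update. The sketch below spells out the per-element math in plain C++; the `mode` semantics (0 = L2-style decay folded into the gradient, 1 = decoupled AdamW decay) are an assumption based on the `adamMode_t` parameter, not something this diff states.

```cpp
#include <cmath>

// Sketch of the per-element step behind AdamFunctor-style kernels; not the kernel code.
void adam_step(float &p, float &m, float &v, float g, float lr, float beta1, float beta2,
               float epsilon, float weight_decay, int step, int mode) {
  if (mode == 0) g += weight_decay * p;           // L2 regularization variant (assumed)
  m = beta1 * m + (1.f - beta1) * g;              // first moment
  v = beta2 * v + (1.f - beta2) * g * g;          // second moment
  const float bc1 = 1.f - std::pow(beta1, step);  // bias corrections
  const float bc2 = 1.f - std::pow(beta2, step);
  float update = (m / bc1) / (std::sqrt(v / bc2) + epsilon);
  if (mode == 1) update += weight_decay * p;      // decoupled weight decay (assumed)
  p -= lr * update;
}
```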
@@ -4,23 +4,21 @@
  * See LICENSE for license information.
  ************************************************************************/

-#include <ATen/ATen.h>
-#include <ATen/AccumulateType.h>
-#include <ATen/cuda/CUDAContext.h>
-#include <ATen/cuda/Exceptions.h>
-// Another possibility:
-// #include <torch/all.h>
-#include <assert.h>
+#include <assert.h>
+#include <cuda_fp8.h>
+#include <transformer_engine/multi_tensor.h>
+#include <transformer_engine/transformer_engine.h>
+
 #include <limits>
 // Stringstream is a big hammer, but I want to rely on operator<< for dtype.
 #include <sstream>

-#include "common/recipe/recipe_common.cuh"
-#include "common/utils.cuh"
+#include "../recipe/recipe_common.cuh"
+#include "../utils.cuh"
 #include "multi_tensor_apply.cuh"
-#include "type_shim.h"
+
+namespace transformer_engine {
+namespace multi_tensor_compute_scale {

 #define BLOCK_SIZE 256
@@ -57,12 +55,29 @@ struct ComputeScaleAndScaleInvFunctor {
   }
 };

-void multi_tensor_compute_scale_and_scale_inv_cuda(
-    int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
-    float max_fp8, bool force_pow_2_scales, float epsilon) {
-  using namespace at;
-
-  multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
-                        ComputeScaleAndScaleInvFunctor(), max_fp8, force_pow_2_scales, epsilon);
-  AT_CUDA_CHECK(cudaGetLastError());
-}
+void multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, Tensor noop_flag,
+                                                   std::vector<std::vector<Tensor *>> tensor_lists,
+                                                   float max_fp8, bool force_pow_2_scales,
+                                                   float epsilon, const int device_id,
+                                                   cudaStream_t stream) {
+  multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
+                        ComputeScaleAndScaleInvFunctor(), device_id, stream, max_fp8,
+                        force_pow_2_scales, epsilon);
+
+  NVTE_CHECK_CUDA(cudaGetLastError());
+}
} // namespace multi_tensor_compute_scale
} // namespace transformer_engine
void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, float max_fp8, int force_pow_2_scales, float epsilon,
const int device_id, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_compute_scale_and_scale_inv_cuda);
using namespace transformer_engine;
multi_tensor_compute_scale::multi_tensor_compute_scale_and_scale_inv_cuda(
chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), max_fp8,
force_pow_2_scales, epsilon, device_id, stream);
}
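As a rough mental model, the functor behind this entry point turns an amax into an FP8 scale and its reciprocal. The sketch below reflects the usual Transformer Engine recipe math; the exact epsilon and non-finite handling inside ComputeScaleAndScaleInvFunctor may differ, so treat it as an assumption-laden illustration rather than the kernel's logic.

```cpp
#include <cmath>

// Sketch: derive scale and scale_inv from an amax, in the spirit of
// ComputeScaleAndScaleInvFunctor. Edge-case behaviour is assumed, not copied.
void compute_scale_and_scale_inv(float amax, float max_fp8, bool force_pow_2_scales,
                                 float epsilon, float *scale, float *scale_inv) {
  if (amax < epsilon) amax = epsilon;            // keep the divisor away from zero
  float s = (std::isfinite(amax) && amax > 0.f) ? max_fp8 / amax : 1.f;
  if (force_pow_2_scales) {
    s = std::exp2f(std::floor(std::log2f(s)));   // round down to a power of two
  }
  *scale = s;
  *scale_inv = 1.f / s;
}
```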
@@ -4,18 +4,16 @@
  * See LICENSE for license information.
  ************************************************************************/

-#include <ATen/ATen.h>
-#include <ATen/AccumulateType.h>
-#include <ATen/cuda/CUDAContext.h>
-#include <ATen/cuda/Exceptions.h>
-#include <c10/cuda/CUDAGuard.h>
-// Another possibility:
-// #include <torch/all.h>
 #include <assert.h>
+#include <cuda_fp8.h>
+#include <transformer_engine/multi_tensor.h>
+#include <transformer_engine/transformer_engine.h>

+#include "../utils.cuh"
 #include "multi_tensor_apply.cuh"
-#include "type_shim.h"
+
+namespace transformer_engine {
+namespace multi_tensor_l2norm {

 #define BLOCK_SIZE 512
 #define ILP 4
@@ -31,6 +29,88 @@ __device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset, int s
   ((LT *)dst)[dst_offset] = ((LT *)src)[src_offset];  // NOLINT(*)
 }
template <typename T>
__device__ __forceinline__ T
reduce_block_into_lanes(T *x, T val, int lanes = 1,
bool share_result = false) { // lanes is intended to be <= 32.
int tid = threadIdx.x + threadIdx.y * blockDim.x;
int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
if (blockSize >= 64) {
x[tid] = val;
__syncthreads();
}
#pragma unroll
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
if (tid < i) x[tid] = x[tid] + x[tid + i];
__syncthreads();
}
T final;
if (tid < 32) {
if (blockSize >= 64)
final = x[tid] + x[tid + 32];
else
final = val;
// __SYNCWARP();
#pragma unroll
for (int i = 16; i >= lanes; i >>= 1) final = final + __shfl_down_sync(0xffffffff, final, i);
}
if (share_result) {
if (tid < lanes) x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
}
__syncthreads();
// Avoid potential write before read race when reduce_block_into_lanes is called back to back
return final;
}
template <typename T>
__device__ __forceinline__ T
reduce_block_into_lanes_max_op(T *x, T val, int lanes = 1,
bool share_result = false) { // lanes is intended to be <= 32.
int tid = threadIdx.x + threadIdx.y * blockDim.x;
int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
if (blockSize >= 64) {
x[tid] = val;
__syncthreads();
}
#pragma unroll
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
__syncthreads();
}
T final;
if (tid < 32) {
if (blockSize >= 64)
final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
else
final = val;
// __SYNCWARP();
#pragma unroll
for (int i = 16; i >= lanes; i >>= 1)
final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
}
if (share_result) {
if (tid < lanes) x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}
 template <typename x_t>
 struct L2NormFunctor {
   __device__ __forceinline__ void operator()(int chunk_size, volatile int *noop_gmem,
@@ -56,7 +136,7 @@ struct L2NormFunctor {
     x_t r_x[ILP];
     for (int i = 0; i < ILP; i++) {
       vals[i] = 0.f;
-      r_x[i] = 0;
+      r_x[i] = 0.f;
     }
     // to make things simple, we put aligned case in a different code path
@@ -126,7 +206,7 @@ struct UnscaleL2NormFunctor {
     x_t r_x[ILP];
     for (int i = 0; i < ILP; i++) {
       vals[i] = 0.f;
-      r_x[i] = 0;
+      r_x[i] = 0.f;
     }
     // to make things simple, we put aligned case in a different code path
@@ -310,103 +390,96 @@ __global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret,
   }
 }

-std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
-    int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
-    at::optional<bool> per_tensor_python) {
-  bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;
-
-  auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
-  auto output = at::zeros({320}, float_options);
-
-  at::Tensor output_per_tensor;
-  at::Tensor ret_per_tensor;
-
-  int ntensors = tensor_lists[0].size();
-  int max_chunks_per_tensor = -1;
-
-  if (per_tensor) {
-    for (int t = 0; t < ntensors; t++) {
-      int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
-      if (max_chunks_this_tensor > max_chunks_per_tensor)
-        max_chunks_per_tensor = max_chunks_this_tensor;
-    }
-    output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor}, float_options);
-    ret_per_tensor = at::empty({ntensors}, float_options);
-  } else {
-    ret_per_tensor = at::empty({0}, float_options);
-  }
-
-  DISPATCH_FLOAT_HALF_AND_BFLOAT(
-      tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
-      multi_tensor_apply<1>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
-                            L2NormFunctor<scalar_t_0>(), output.data_ptr<float>(),
-                            per_tensor ? output_per_tensor.data_ptr<float>() : nullptr, per_tensor,
-                            max_chunks_per_tensor);)
-
-  AT_CUDA_CHECK(cudaGetLastError());
-  // AT_CUDA_CHECK(cudaDeviceSynchronize());
+void multi_tensor_l2norm_cuda(int chunk_size, Tensor noop_flag,
+                              std::vector<std::vector<Tensor *>> tensor_lists, Tensor output,
+                              Tensor output_per_tensor, Tensor ret, Tensor ret_per_tensor,
+                              bool per_tensor, int max_chunks_per_tensor, const int device_id,
+                              cudaStream_t stream) {
+  TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+      tensor_lists[0][0]->dtype(), dtype,
+      multi_tensor_apply<1>(
+          BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, L2NormFunctor<dtype>(), device_id,
+          stream, reinterpret_cast<float *>(output.data.dptr),
+          per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr, per_tensor,
+          max_chunks_per_tensor);)
+
+  NVTE_CHECK_CUDA(cudaGetLastError());

   // This involves one more small kernel launches, but will be negligible end to end.
   // I could get rid of these by hacking the functor + multi tensor harness with persistence
   // logic, but keeping it simple for now
-  auto ret = at::empty({1}, output.options());
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
-  auto stream = at::cuda::getCurrentCUDAStream();
-  cleanup<<<per_tensor ? ntensors : 1, 512, 0, stream>>>(
-      output.data_ptr<float>(), per_tensor ? output_per_tensor.data_ptr<float>() : nullptr,
-      ret.data_ptr<float>(), per_tensor ? ret_per_tensor.data_ptr<float>() : nullptr, per_tensor,
+  const OptionalCUDAGuard device_guard(device_id);
+  cleanup<<<per_tensor ? tensor_lists[0].size() : 1, 512, 0, stream>>>(
+      reinterpret_cast<float *>(output.data.dptr),
+      per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr,
+      reinterpret_cast<float *>(ret.data.dptr),
+      per_tensor ? reinterpret_cast<float *>(ret_per_tensor.data.dptr) : nullptr, per_tensor,
       max_chunks_per_tensor);
-
-  return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
 }

-std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(
-    int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
-    at::Tensor inv_scale, at::optional<bool> per_tensor_python) {
-  bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;
-
-  auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
-  auto output = at::zeros({320}, float_options);
-
-  at::Tensor output_per_tensor;
-  at::Tensor ret_per_tensor;
-
-  int ntensors = tensor_lists[0].size();
-  int max_chunks_per_tensor = -1;
-
-  if (per_tensor) {
-    for (int t = 0; t < ntensors; t++) {
-      int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
-      if (max_chunks_this_tensor > max_chunks_per_tensor)
-        max_chunks_per_tensor = max_chunks_this_tensor;
-    }
-    output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor}, float_options);
-    ret_per_tensor = at::empty({ntensors}, float_options);
-  } else {
-    ret_per_tensor = at::empty({0}, float_options);
-  }
-
-  DISPATCH_FLOAT_HALF_AND_BFLOAT(
-      tensor_lists[0][0].scalar_type(), 0, "multi_tensor_unscale_l2norm_cuda",
-      multi_tensor_apply<1>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
-                            UnscaleL2NormFunctor<scalar_t_0>(), inv_scale.data_ptr<float>(),
-                            output.data_ptr<float>(),
-                            per_tensor ? output_per_tensor.data_ptr<float>() : nullptr, per_tensor,
-                            max_chunks_per_tensor);)
-
-  AT_CUDA_CHECK(cudaGetLastError());
-  // AT_CUDA_CHECK(cudaDeviceSynchronize());
+void multi_tensor_unscale_l2norm_cuda(int chunk_size, Tensor noop_flag,
+                                      std::vector<std::vector<Tensor *>> tensor_lists,
+                                      Tensor output, Tensor output_per_tensor, Tensor ret,
+                                      Tensor ret_per_tensor, Tensor inv_scale, bool per_tensor,
+                                      int max_chunks_per_tensor, const int device_id,
+                                      cudaStream_t stream) {
+  TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+      tensor_lists[0][0]->dtype(), dtype,
+      multi_tensor_apply<1>(
+          BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, UnscaleL2NormFunctor<dtype>(), device_id,
+          stream, reinterpret_cast<float *>(inv_scale.data.dptr),
+          reinterpret_cast<float *>(output.data.dptr),
+          per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr, per_tensor,
+          max_chunks_per_tensor);)
+
+  NVTE_CHECK_CUDA(cudaGetLastError());

   // This involves one more small kernel launches, but will be negligible end to end.
   // I could get rid of these by hacking the functor + multi tensor harness with persistence
   // logic, but keeping it simple for now
-  auto ret = at::empty({1}, output.options());
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
-  auto stream = at::cuda::getCurrentCUDAStream();
-  cleanup<<<per_tensor ? ntensors : 1, 512, 0, stream>>>(
-      output.data_ptr<float>(), per_tensor ? output_per_tensor.data_ptr<float>() : nullptr,
-      ret.data_ptr<float>(), per_tensor ? ret_per_tensor.data_ptr<float>() : nullptr, per_tensor,
+  const OptionalCUDAGuard device_guard(device_id);
+  cleanup<<<per_tensor ? tensor_lists[0].size() : 1, 512, 0, stream>>>(
+      reinterpret_cast<float *>(output.data.dptr),
+      per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr,
+      reinterpret_cast<float *>(ret.data.dptr),
+      per_tensor ? reinterpret_cast<float *>(ret_per_tensor.data.dptr) : nullptr, per_tensor,
       max_chunks_per_tensor);
-
-  return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
 }
} // namespace multi_tensor_l2norm
} // namespace transformer_engine
void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
NVTETensor output, NVTETensor output_per_tensor, NVTETensor ret,
NVTETensor ret_per_tensor, int per_tensor,
int max_chunks_per_tensor, const int device_id,
cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_l2norm_cuda);
using namespace transformer_engine;
multi_tensor_l2norm::multi_tensor_l2norm_cuda(
chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
*reinterpret_cast<Tensor *>(output), *reinterpret_cast<Tensor *>(output_per_tensor),
*reinterpret_cast<Tensor *>(ret), *reinterpret_cast<Tensor *>(ret_per_tensor), per_tensor,
max_chunks_per_tensor, device_id, stream);
}
void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, NVTETensor output,
NVTETensor output_per_tensor, NVTETensor ret,
NVTETensor ret_per_tensor, NVTETensor inv_scale,
int per_tensor, int max_chunks_per_tensor,
const int device_id, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_unscale_l2norm_cuda);
using namespace transformer_engine;
multi_tensor_l2norm::multi_tensor_unscale_l2norm_cuda(
chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
*reinterpret_cast<Tensor *>(output), *reinterpret_cast<Tensor *>(output_per_tensor),
*reinterpret_cast<Tensor *>(ret), *reinterpret_cast<Tensor *>(ret_per_tensor),
*reinterpret_cast<Tensor *>(inv_scale), per_tensor, max_chunks_per_tensor, device_id, stream);
}
@@ -5,17 +5,62 @@
  ************************************************************************/

 #pragma once

-#include <ATen/ATen.h>
-#include <ATen/AccumulateType.h>
-#include <ATen/cuda/CUDAContext.h>
-#include <ATen/cuda/Exceptions.h>
 #include <assert.h>
-#include <c10/cuda/CUDAGuard.h>
+#include <cuda_runtime.h>
+#include <transformer_engine/multi_tensor.h>
+#include <transformer_engine/transformer_engine.h>

-#include "common/common.h"
+#include "../common.h"

 // This header is the one-stop shop for all your multi-tensor apply needs.
// Change device if needed.
class OptionalCUDAGuard {
public:
explicit OptionalCUDAGuard(int new_device) {
if (new_device < 0) return;
int current_device;
NVTE_CHECK_CUDA(cudaGetDevice(&current_device));
if (new_device != current_device) {
NVTE_CHECK_CUDA(cudaSetDevice(new_device));
device_changed_ = true;
prev_device_ = current_device;
}
}
OptionalCUDAGuard(const OptionalCUDAGuard &) = delete;
OptionalCUDAGuard &operator=(const OptionalCUDAGuard &) = delete;
OptionalCUDAGuard(OptionalCUDAGuard &&other) noexcept
: prev_device_(other.prev_device_), device_changed_(other.device_changed_) {
other.device_changed_ = false;
}
OptionalCUDAGuard &operator=(OptionalCUDAGuard &&other) noexcept {
if (this != &other) {
if (device_changed_) {
cudaSetDevice(prev_device_);
}
prev_device_ = other.prev_device_;
device_changed_ = other.device_changed_;
other.device_changed_ = false;
}
return *this;
}
~OptionalCUDAGuard() {
if (device_changed_) {
NVTE_CHECK_CUDA(cudaSetDevice(prev_device_));
}
}
private:
int prev_device_;
bool device_changed_ = false;
};
// TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson) // TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson)
constexpr int depth_to_max_tensors[6] = {110, 64, 48, 36, 30, 24}; constexpr int depth_to_max_tensors[6] = {110, 64, 48, 36, 30, 24};
constexpr int depth_to_max_blocks[6] = {320, 320, 320, 320, 320, 320}; constexpr int depth_to_max_blocks[6] = {320, 320, 320, 320, 320, 320};
...@@ -46,61 +91,39 @@ __global__ void multi_tensor_apply_kernel(int64_t chunk_size, volatile int *noop ...@@ -46,61 +91,39 @@ __global__ void multi_tensor_apply_kernel(int64_t chunk_size, volatile int *noop
} }
template <int depth, bool USE_FP8 = false, typename T, typename... ArgTypes> template <int depth, bool USE_FP8 = false, typename T, typename... ArgTypes>
void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor &noop_flag, void multi_tensor_apply(int64_t block_size, int64_t chunk_size,
const std::vector<std::vector<at::Tensor>> &tensor_lists, T callable, const transformer_engine::Tensor &noop_flag,
ArgTypes... args) { std::vector<std::vector<transformer_engine::Tensor *>> tensor_lists,
if constexpr (USE_FP8) { T callable, const int device_id, cudaStream_t stream, ArgTypes... args) {
TORCH_CHECK(tensor_lists.size() == depth + 3, const size_t num_tensor_lists = tensor_lists.size();
"tensor_lists.size() != depth + 3, tensor_lists should have 3 more tensors (scale, " const size_t num_tensors_per_list = tensor_lists[0].size();
"amax, scale_inv) for fp8");
} else {
TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
}
int len0 = tensor_lists[0].size();
TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
auto ref_device = tensor_lists[0][0].device();
TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
for (int l = 0; l < depth; l++) { // No range-based for because I need indices
TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
for (int t = 0; t < tensor_lists[l].size(); t++) {
// TODO: Print which tensor fails.
bool contiguous_memory = tensor_lists[l][t].is_contiguous();
contiguous_memory =
(contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast) ||
tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast3d));
TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
TORCH_CHECK(tensor_lists[l][t].device() == ref_device,
"A tensor was not on the same device as the first tensor");
TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
}
}
if constexpr (USE_FP8) { if constexpr (USE_FP8) {
TORCH_CHECK(tensor_lists[depth].size() == len0 && tensor_lists[depth + 1].size() == len0, NVTE_CHECK(num_tensor_lists == depth + 3,
"Size mismatch among tensor lists"); "tensor_lists.size() != depth + 3, tensor_lists should have 3 more tensors (scale, "
"amax, scale_inv) for fp8");
} else {
NVTE_CHECK(num_tensor_lists == depth, "tensor_lists.size() != depth");
} }
int ntensors = tensor_lists[0].size();
TensorListMetadata<depth, USE_FP8> tl; TensorListMetadata<depth, USE_FP8> tl;
const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0])); const OptionalCUDAGuard device_guard(device_id);
auto stream = at::cuda::getCurrentCUDAStream();
tl.start_tensor_this_launch = 0; tl.start_tensor_this_launch = 0;
int loc_block_info = 0; int loc_block_info = 0;
int loc_tensor_info = 0; int loc_tensor_info = 0;
for (int t = 0; t < ntensors; t++) { for (int t = 0; t < num_tensors_per_list; t++) {
tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel(); tl.sizes[loc_tensor_info] = tensor_lists[0][t]->numel();
for (int d = 0; d < depth; d++) for (int d = 0; d < depth; d++)
tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr(); tl.addresses[d][loc_tensor_info] = tensor_lists[d][t]->data.dptr;
if constexpr (USE_FP8) { if constexpr (USE_FP8) {
for (int i = 0; i < 3; i++) for (int i = 0; i < 3; i++)
tl.fp8_meta_addresses[i][loc_tensor_info] = tensor_lists[depth + i][t].data_ptr(); tl.fp8_meta_addresses[i][loc_tensor_info] = tensor_lists[depth + i][t]->data.dptr;
} }
loc_tensor_info++; loc_tensor_info++;
auto chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size; auto chunks_this_tensor = (tensor_lists[0][t]->numel() + chunk_size - 1) / chunk_size;
for (auto chunk = 0; chunk < chunks_this_tensor; chunk++) { for (auto chunk = 0; chunk < chunks_this_tensor; chunk++) {
tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1; tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
...@@ -110,12 +133,12 @@ void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor ...@@ -110,12 +133,12 @@ void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const at::Tensor
bool tensors_full = bool tensors_full =
(loc_tensor_info == depth_to_max_tensors[depth - 1] && chunk == chunks_this_tensor - 1); (loc_tensor_info == depth_to_max_tensors[depth - 1] && chunk == chunks_this_tensor - 1);
bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]); bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]);
bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1); bool last_chunk = (t == num_tensors_per_list - 1 && chunk == chunks_this_tensor - 1);
if (tensors_full || blocks_full || last_chunk) { if (tensors_full || blocks_full || last_chunk) {
multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>( multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(
chunk_size, noop_flag.data_ptr<int>(), tl, callable, args...); chunk_size, reinterpret_cast<int *>(noop_flag.data.dptr), tl, callable, args...);
AT_CUDA_CHECK(cudaGetLastError()); NVTE_CHECK_CUDA(cudaGetLastError());
// Reset. The control flow possibilities here make my brain hurt. // Reset. The control flow possibilities here make my brain hurt.
loc_block_info = 0; loc_block_info = 0;
......
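Editor's note: the host-side loop in this hunk packs (tensor, chunk) pairs into a fixed-size TensorListMetadata struct and flushes a kernel launch whenever the tensor slots, the block slots, or the input itself run out. Below is a minimal host-only sketch of that packing policy, written for illustration; MAX_TENSORS and MAX_BLOCKS are placeholders for the real depth_to_max_tensors / depth_to_max_blocks tables, and the printf stands in for the kernel launch.

// Host-only illustration of the chunk-packing policy used by multi_tensor_apply.
#include <cstdio>
#include <vector>

int main() {
  const long chunk_size = 65536;
  const int MAX_TENSORS = 36;   // placeholder for depth_to_max_tensors[depth - 1]
  const int MAX_BLOCKS = 320;   // placeholder for depth_to_max_blocks[depth - 1]
  std::vector<long> numels = {1000000, 4096, 7777777};  // example tensor sizes

  int loc_tensor = 0, loc_block = 0, launches = 0;
  for (size_t t = 0; t < numels.size(); ++t) {
    ++loc_tensor;  // this tensor occupies one slot in tl.sizes / tl.addresses
    const long chunks = (numels[t] + chunk_size - 1) / chunk_size;
    for (long c = 0; c < chunks; ++c) {
      ++loc_block;  // one (tensor, chunk) pair becomes one CUDA block
      const bool tensors_full = (loc_tensor == MAX_TENSORS && c == chunks - 1);
      const bool blocks_full = (loc_block == MAX_BLOCKS);
      const bool last_chunk = (t == numels.size() - 1 && c == chunks - 1);
      if (tensors_full || blocks_full || last_chunk) {
        // In the real code this is where multi_tensor_apply_kernel is launched
        // with loc_block blocks; the metadata is then reset and, if the current
        // tensor still has chunks left, it is re-registered as slot 0.
        std::printf("launch %d: %d blocks, %d tensors\n", ++launches, loc_block, loc_tensor);
        loc_block = 0;
        loc_tensor = (c == chunks - 1) ? 0 : 1;
      }
    }
  }
  return 0;
}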
...@@ -4,19 +4,20 @@ ...@@ -4,19 +4,20 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h> #include <assert.h>
#include <cuda_fp8.h>
// Stringstream is a big hammer, but I want to rely on operator<< for dtype. // Stringstream is a big hammer, but I want to rely on operator<< for dtype.
#include <transformer_engine/multi_tensor.h>
#include <transformer_engine/transformer_engine.h>
#include <iostream>
#include <sstream> #include <sstream>
#include "../utils.cuh"
#include "multi_tensor_apply.cuh" #include "multi_tensor_apply.cuh"
#include "type_shim.h"
namespace transformer_engine {
namespace multi_tensor_scale {
#define BLOCK_SIZE 512 #define BLOCK_SIZE 512
#define ILP 4 #define ILP 4
...@@ -66,7 +67,7 @@ struct ScaleFunctor { ...@@ -66,7 +67,7 @@ struct ScaleFunctor {
#pragma unroll #pragma unroll
for (int ii = 0; ii < ILP; ii++) { for (int ii = 0; ii < ILP; ii++) {
r_out[ii] = static_cast<float>(r_in[ii]) * scale; r_out[ii] = static_cast<float>(r_in[ii]) * scale;
finite = finite && isfinite(r_in[ii]); finite = finite && isfinite(static_cast<float>(r_in[ii]));
} }
// store // store
load_store(out, r_out, i_start, 0); load_store(out, r_out, i_start, 0);
...@@ -76,7 +77,7 @@ struct ScaleFunctor { ...@@ -76,7 +77,7 @@ struct ScaleFunctor {
for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) { for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
#pragma unroll #pragma unroll
for (int ii = 0; ii < ILP; ii++) { for (int ii = 0; ii < ILP; ii++) {
r_in[ii] = 0; r_in[ii] = 0.f;
int i = i_start + threadIdx.x + ii * blockDim.x; int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) r_in[ii] = in[i]; if (i < n && i < chunk_size) r_in[ii] = in[i];
} }
...@@ -88,7 +89,7 @@ struct ScaleFunctor { ...@@ -88,7 +89,7 @@ struct ScaleFunctor {
#pragma unroll #pragma unroll
for (int ii = 0; ii < ILP; ii++) { for (int ii = 0; ii < ILP; ii++) {
r_out[ii] = static_cast<float>(r_in[ii]) * scale; r_out[ii] = static_cast<float>(r_in[ii]) * scale;
finite = finite && isfinite(r_in[ii]); finite = finite && isfinite(static_cast<float>(r_in[ii]));
} }
#pragma unroll #pragma unroll
for (int ii = 0; ii < ILP; ii++) { for (int ii = 0; ii < ILP; ii++) {
...@@ -101,20 +102,29 @@ struct ScaleFunctor { ...@@ -101,20 +102,29 @@ struct ScaleFunctor {
} }
}; };
void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag, void multi_tensor_scale_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, float scale) { std::vector<std::vector<Tensor *>> tensor_lists, float scale,
using namespace at; const int device_id, cudaStream_t stream) {
// The output (downscaled) type is always float. TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
// If build times suffer, think about where to put this dispatch, tensor_lists[0][0]->dtype(), p_in_type,
// and what logic should be moved out of multi_tensor_apply. TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
tensor_lists[1][0]->dtype(), g_in_type,
DISPATCH_FLOAT_HALF_AND_BFLOAT(
tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda",
DISPATCH_FLOAT_HALF_AND_BFLOAT(
tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda",
multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
ScaleFunctor<scalar_t_0, scalar_t_1>(), scale);)) ScaleFunctor<p_in_type, g_in_type>(), device_id, stream, scale);))
AT_CUDA_CHECK(cudaGetLastError()); NVTE_CHECK_CUDA(cudaGetLastError());
}
} // namespace multi_tensor_scale
} // namespace transformer_engine
void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
float scale, const int device_id, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_scale_cuda);
using namespace transformer_engine;
// AT_CUDA_CHECK(cudaDeviceSynchronize()); multi_tensor_scale::multi_tensor_scale_cuda(
chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), scale, device_id,
stream);
} }
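Editor's note: the contract of the scale kernel moved in this file is simple — each output element is the corresponding input element times one scalar, and the no-op flag is raised if any input value is non-finite so a downstream step can be skipped. The following is a hedged CPU reference of those semantics only, not the chunked CUDA kernel itself.

// CPU reference for the assumed semantics of nvte_multi_tensor_scale_cuda:
// out[l][i] = in[l][i] * scale; noop_flag becomes 1 if any input is non-finite.
#include <cmath>
#include <cstdio>
#include <vector>

void multi_tensor_scale_reference(const std::vector<std::vector<float>>& in,
                                  std::vector<std::vector<float>>& out,
                                  float scale, int& noop_flag) {
  for (size_t l = 0; l < in.size(); ++l) {
    out[l].resize(in[l].size());
    for (size_t i = 0; i < in[l].size(); ++i) {
      if (!std::isfinite(in[l][i])) noop_flag = 1;  // poisoned grads -> skip this step
      out[l][i] = in[l][i] * scale;
    }
  }
}

int main() {
  std::vector<std::vector<float>> in = {{1.f, 2.f}, {4.f, HUGE_VALF}};
  std::vector<std::vector<float>> out(in.size());
  int noop = 0;
  multi_tensor_scale_reference(in, out, 0.5f, noop);
  std::printf("noop=%d out[0][1]=%f\n", noop, out[0][1]);  // noop=1, out[0][1]=1.0
  return 0;
}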
...@@ -4,14 +4,16 @@ ...@@ -4,14 +4,16 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <assert.h> #include <assert.h>
#include <cuda_fp8.h>
#include <transformer_engine/multi_tensor.h>
#include <transformer_engine/transformer_engine.h>
#include "../utils.cuh"
#include "multi_tensor_apply.cuh" #include "multi_tensor_apply.cuh"
#include "type_shim.h"
namespace transformer_engine {
namespace multi_tensor_sgd {
#define BLOCK_SIZE 512 #define BLOCK_SIZE 512
#define ILP 4 #define ILP 4
...@@ -54,9 +56,9 @@ struct SGDFunctor { ...@@ -54,9 +56,9 @@ struct SGDFunctor {
T_weight* mom_in = reinterpret_cast<T_weight*>(tl.addresses[2][tensor_loc]); T_weight* mom_in = reinterpret_cast<T_weight*>(tl.addresses[2][tensor_loc]);
mom_in += chunk_idx * chunk_size; mom_in += chunk_idx * chunk_size;
at::Half* model_weights_out = nullptr; fp16* model_weights_out = nullptr;
if (N == 4) { if (N == 4) {
model_weights_out = (at::Half*)tl.addresses[3][tensor_loc]; model_weights_out = reinterpret_cast<fp16*>(tl.addresses[3][tensor_loc]);
model_weights_out += chunk_idx * chunk_size; model_weights_out += chunk_idx * chunk_size;
} }
...@@ -112,7 +114,7 @@ struct SGDFunctor { ...@@ -112,7 +114,7 @@ struct SGDFunctor {
weight_in[i] += (-lr * incoming_grads[ii]); weight_in[i] += (-lr * incoming_grads[ii]);
// if necessary, write out an fp16 copy of the weights // if necessary, write out an fp16 copy of the weights
if (N == 4) model_weights_out[i] = static_cast<at::Half>(weight_in[i]); if (N == 4) model_weights_out[i] = static_cast<fp16>(weight_in[i]);
// also write out the new momentum // also write out the new momentum
if (momentum != 0.f) mom_in[i] = incoming_moms[ii]; if (momentum != 0.f) mom_in[i] = incoming_moms[ii];
...@@ -122,23 +124,23 @@ struct SGDFunctor { ...@@ -122,23 +124,23 @@ struct SGDFunctor {
} }
}; };
void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag, void multi_tensor_sgd_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, float wd, std::vector<std::vector<Tensor*>> tensor_lists, float wd, float momentum,
float momentum, float dampening, float lr, bool nesterov, bool first_run, float dampening, float lr, bool nesterov, bool first_run,
bool wd_after_momentum, float scale) { bool wd_after_momentum, float scale, const int device_id,
auto num_tensors = tensor_lists.size(); cudaStream_t stream) {
auto grad_type = tensor_lists[0][0].scalar_type(); const size_t num_tensor_lists = tensor_lists.size();
auto weight_type = tensor_lists[1][0].scalar_type(); const size_t num_tensors_per_list = tensor_lists[0].size();
if (num_tensors == 4) { auto grad_type = tensor_lists[0][0]->dtype();
for (int i = 0; i < tensor_lists[3].size(); i++) auto weight_type = tensor_lists[1][0]->dtype();
TORCH_CHECK(tensor_lists[3][i].scalar_type() == at::ScalarType::Half,
"Additional output tensors should always be fp16."); if (num_tensor_lists == 4) {
for (int i = 0; i < num_tensors_per_list; i++)
NVTE_CHECK(tensor_lists[3][i]->dtype() == DType::kFloat16,
"Additional output tensors should always be fp16.");
} }
TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(),
"expected noop flag to be on the same device as tensors");
// We have 3 possibilities to handle here, in terms of // We have 3 possibilities to handle here, in terms of
// grad_type, param_type, momentum_type, requires_fp16_copy // grad_type, param_type, momentum_type, requires_fp16_copy
// 1. fp16, fp16, fp16, No // 1. fp16, fp16, fp16, No
...@@ -150,54 +152,51 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag, ...@@ -150,54 +152,51 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
// we don't want the majority of them. // we don't want the majority of them.
// Case 1. fp16, fp16, fp16, No // Case 1. fp16, fp16, fp16, No
if (grad_type == at::ScalarType::Half && weight_type == at::ScalarType::Half && if (grad_type == DType::kFloat16 && weight_type == DType::kFloat16 && num_tensor_lists == 3) {
num_tensors == 3) {
multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
SGDFunctor<3, at::Half, at::Half>(), wd, momentum, dampening, lr, SGDFunctor<3, fp16, fp16>(), device_id, stream, wd, momentum, dampening,
nesterov, first_run, wd_after_momentum, scale); lr, nesterov, first_run, wd_after_momentum, scale);
} }
// Case 2. fp16, fp32, fp32, No
// else if (grad_type == at::ScalarType::Half &&
// weight_type == at::ScalarType::Float &&
// num_tensors == 3) {
// multi_tensor_apply<3>(
// BLOCK_SIZE,
// chunk_size,
// noop_flag,
// tensor_lists,
// SGDFunctor<3, at::Half, float>(),
// wd,
// momentum,
// dampening,
// lr,
// nesterov,
// first_run,
// wd_after_momentum);
// }
// Case 2. fp32, fp32, fp32, No // Case 2. fp32, fp32, fp32, No
else if (grad_type == at::ScalarType::Float && // NOLINT(*) else if (grad_type == DType::kFloat32 && // NOLINT(*)
weight_type == at::ScalarType::Float && num_tensors == 3) { weight_type == DType::kFloat32 && num_tensor_lists == 3) {
multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
SGDFunctor<3, float, float>(), wd, momentum, dampening, lr, nesterov, SGDFunctor<3, float, float>(), device_id, stream, wd, momentum, dampening,
first_run, wd_after_momentum, scale); lr, nesterov, first_run, wd_after_momentum, scale);
} }
// Case 3. fp16, fp32, fp32, Yes // Case 3. fp16, fp32, fp32, Yes
else if (grad_type == at::ScalarType::Half && // NOLINT(*) else if (grad_type == DType::kFloat16 && // NOLINT(*)
weight_type == at::ScalarType::Float && num_tensors == 4) { weight_type == DType::kFloat32 && num_tensor_lists == 4) {
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
SGDFunctor<4, at::Half, float>(), wd, momentum, dampening, lr, nesterov, SGDFunctor<4, fp16, float>(), device_id, stream, wd, momentum, dampening,
first_run, wd_after_momentum, scale); lr, nesterov, first_run, wd_after_momentum, scale);
} }
// Case 4. fp32, fp32, fp32, Yes // Case 4. fp32, fp32, fp32, Yes
else if (grad_type == at::ScalarType::Float && // NOLINT(*) else if (grad_type == DType::kFloat32 && // NOLINT(*)
weight_type == at::ScalarType::Float && num_tensors == 4) { weight_type == DType::kFloat32 && num_tensor_lists == 4) {
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
SGDFunctor<4, float, float>(), wd, momentum, dampening, lr, nesterov, SGDFunctor<4, float, float>(), device_id, stream, wd, momentum, dampening,
first_run, wd_after_momentum, scale); lr, nesterov, first_run, wd_after_momentum, scale);
} else { } else {
AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ", NVTE_ERROR("Unsupported combination of weight and gradient types.");
"gradient: ", grad_type, ", weight: ", weight_type, ", num_lists: ", num_tensors);
} }
AT_CUDA_CHECK(cudaGetLastError()); NVTE_CHECK_CUDA(cudaGetLastError());
}
} // namespace multi_tensor_sgd
} // namespace transformer_engine
void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor** tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
float wd, float momentum, float dampening, float lr, int nesterov,
int first_run, int wd_after_momentum, float scale,
const int device_id, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_sgd_cuda);
using namespace transformer_engine;
multi_tensor_sgd::multi_tensor_sgd_cuda(
chunk_size, *reinterpret_cast<Tensor*>(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), wd, momentum,
dampening, lr, nesterov, first_run, wd_after_momentum, scale, device_id, stream);
} }
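Editor's note: the four dispatch branches above differ only in storage types; the update they all perform follows the classic Apex fused-SGD recipe. A scalar CPU sketch of that recipe follows; it is my reading of the functor (parameter names mirror the API), not a verbatim extract of the kernel.

// Scalar sketch of one SGD element update as performed (per element) by SGDFunctor:
// the gradient is unscaled, weight decay is applied before or after the momentum
// update, and nesterov changes how the effective gradient is formed.
#include <cstdio>

void sgd_element(float& weight, float grad, float& momentum_buf,
                 float wd, float momentum, float dampening, float lr,
                 bool nesterov, bool first_run, bool wd_after_momentum,
                 float scale) {
  grad *= scale;                                   // undo loss scaling
  if (!wd_after_momentum) grad += wd * weight;     // L2 weight decay (before momentum)
  if (momentum != 0.f) {
    momentum_buf = first_run ? grad
                             : momentum_buf * momentum + (1.f - dampening) * grad;
    grad = nesterov ? grad + momentum * momentum_buf : momentum_buf;
  }
  if (wd_after_momentum) grad += wd * weight;      // L2 weight decay (after momentum)
  weight += -lr * grad;                            // in-place parameter update
}

int main() {
  float w = 1.0f, m = 0.0f;
  sgd_element(w, /*grad=*/0.1f, m, /*wd=*/0.01f, /*momentum=*/0.9f,
              /*dampening=*/0.f, /*lr=*/0.1f, /*nesterov=*/false,
              /*first_run=*/true, /*wd_after_momentum=*/false, /*scale=*/1.f);
  std::printf("w=%f m=%f\n", w, m);
  return 0;
}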
...@@ -49,11 +49,11 @@ std::string to_string(const DType type) { ...@@ -49,11 +49,11 @@ std::string to_string(const DType type) {
std::string to_string(const NVTEScalingMode &mode) { std::string to_string(const NVTEScalingMode &mode) {
switch (mode) { switch (mode) {
case NVTE_DELAYED_TENSOR_SCALING: case NVTE_DELAYED_TENSOR_SCALING:
return "Delayed Tensor Scaling"; return "NVTE_DELAYED_TENSOR_SCALING";
case NVTE_MXFP8_1D_SCALING: case NVTE_MXFP8_1D_SCALING:
return "MXFP8 1D Scaling"; return "NVTE_MXFP8_1D_SCALING";
case NVTE_INVALID_SCALING: case NVTE_INVALID_SCALING:
return "Invalid Scaling"; return "NVTE_INVALID_SCALING";
} }
return "Invalid Scaling"; return "Invalid Scaling";
} }
......
...@@ -97,6 +97,43 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor) ...@@ -97,6 +97,43 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor)
return makeTransformerEngineTensor(tensor.data_ptr(), shape, dtype); return makeTransformerEngineTensor(tensor.data_ptr(), shape, dtype);
} }
std::tuple<std::vector<transformer_engine::TensorWrapper>, std::vector<std::vector<NVTETensor>>,
std::vector<NVTETensor*>, size_t, size_t>
makeTransformerEngineTensorList(std::vector<std::vector<at::Tensor>> at_tensor_lists) {
size_t num_lists = at_tensor_lists.size();
NVTE_CHECK(num_lists > 0, "List of tensors is empty.");
size_t num_tensors = at_tensor_lists[0].size();
std::vector<std::vector<NVTETensor>> nvte_tensor_lists;
std::vector<NVTETensor*> nvte_tensor_list_ptrs;
std::vector<transformer_engine::TensorWrapper> tensorWrappers;
nvte_tensor_lists.reserve(num_lists);
nvte_tensor_list_ptrs.reserve(num_lists);
tensorWrappers.reserve(num_lists * num_tensors);
for (const auto& at_list : at_tensor_lists) {
NVTE_CHECK(at_list.size() == num_tensors, "Wrong number of tensors");
std::vector<NVTETensor> te_list;
te_list.reserve(num_tensors);
for (const auto& at_tensor : at_list) {
tensorWrappers.push_back(makeTransformerEngineTensor(at_tensor));
te_list.push_back(tensorWrappers.back().data());
}
nvte_tensor_lists.push_back(std::move(te_list));
}
for (auto& te_tensor_list : nvte_tensor_lists) {
nvte_tensor_list_ptrs.push_back(te_tensor_list.data());
}
return std::make_tuple(std::move(tensorWrappers), std::move(nvte_tensor_lists),
std::move(nvte_tensor_list_ptrs), num_lists, num_tensors);
}
transformer_engine::TensorWrapper makeTransformerEngineTensor( transformer_engine::TensorWrapper makeTransformerEngineTensor(
void* data_ptr, const std::vector<size_t>& shape, const transformer_engine::DType type, void* data_ptr, const std::vector<size_t>& shape, const transformer_engine::DType type,
void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, std::vector<size_t> scale_inv_shape, void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, std::vector<size_t> scale_inv_shape,
......
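Editor's note: makeTransformerEngineTensorList deliberately returns the TensorWrapper storage alongside the raw NVTETensor pointers: the wrappers own the handles, so a caller must keep the first tuple element alive for as long as the pointer array is in use. A hypothetical caller (the function name and chunk size are illustrative, the call pattern mirrors the wrappers added later in this commit) might look like this.

// Hypothetical caller sketch: the structured binding keeps `wrappers` (and the
// nested NVTETensor vectors) alive for the duration of the nvte_* call, which
// is why the helper returns them instead of only the raw pointer array.
#include "extensions.h"

void launch_scale_example(std::vector<std::vector<at::Tensor>> lists,
                          at::Tensor noop_flag, float scale) {
  using namespace transformer_engine;
  using namespace transformer_engine::pytorch;
  auto noop_cu = makeTransformerEngineTensor(noop_flag);
  auto [wrappers, nvte_lists, list_ptrs, num_lists, num_tensors] =
      makeTransformerEngineTensorList(lists);
  const int device_id = lists[0][0].device().index();
  nvte_multi_tensor_scale_cuda(/*chunk_size=*/65536, noop_cu.data(), list_ptrs.data(),
                               num_lists, num_tensors, scale, device_id,
                               at::cuda::getCurrentCUDAStream());
  // wrappers, nvte_lists, and list_ptrs go out of scope only after the call returns.
}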
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include <transformer_engine/fused_attn.h> #include <transformer_engine/fused_attn.h>
#include <transformer_engine/fused_rope.h> #include <transformer_engine/fused_rope.h>
#include <transformer_engine/gemm.h> #include <transformer_engine/gemm.h>
#include <transformer_engine/multi_tensor.h>
#include <transformer_engine/normalization.h> #include <transformer_engine/normalization.h>
#include <transformer_engine/padding.h> #include <transformer_engine/padding.h>
#include <transformer_engine/permutation.h> #include <transformer_engine/permutation.h>
...@@ -219,6 +220,8 @@ transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, ...@@ -219,6 +220,8 @@ transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
inline at::ScalarType GetATenDType(transformer_engine::DType t) { inline at::ScalarType GetATenDType(transformer_engine::DType t) {
switch (t) { switch (t) {
case transformer_engine::DType::kInt16:
return torch::kInt16;
case transformer_engine::DType::kInt32: case transformer_engine::DType::kInt32:
return torch::kInt32; return torch::kInt32;
case transformer_engine::DType::kInt64: case transformer_engine::DType::kInt64:
...@@ -256,6 +259,8 @@ inline transformer_engine::DType GetTransformerEngineDType(at::ScalarType t) { ...@@ -256,6 +259,8 @@ inline transformer_engine::DType GetTransformerEngineDType(at::ScalarType t) {
return transformer_engine::DType::kByte; return transformer_engine::DType::kByte;
case torch::kByte: case torch::kByte:
return transformer_engine::DType::kByte; return transformer_engine::DType::kByte;
case torch::kInt16:
return transformer_engine::DType::kInt16;
case torch::kInt32: case torch::kInt32:
return transformer_engine::DType::kInt32; return transformer_engine::DType::kInt32;
case torch::kInt64: case torch::kInt64:
...@@ -293,6 +298,10 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr, ...@@ -293,6 +298,10 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr,
transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor); transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor);
std::tuple<std::vector<transformer_engine::TensorWrapper>, std::vector<std::vector<NVTETensor>>,
std::vector<NVTETensor*>, size_t, size_t>
makeTransformerEngineTensorList(std::vector<std::vector<at::Tensor>> at_tensor_lists);
TensorWrapper makeTransformerEngineTensor(py::handle tensor, py::handle quantizer); TensorWrapper makeTransformerEngineTensor(py::handle tensor, py::handle quantizer);
transformer_engine::TensorWrapper makeTransformerEngineTensor( transformer_engine::TensorWrapper makeTransformerEngineTensor(
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction,
const float weight_decay) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_adam_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists,
num_tensors, lr, beta1, beta2, epsilon, step, mode, bias_correction,
weight_decay, device_id, at::cuda::getCurrentCUDAStream());
}
void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode,
const int bias_correction, const float weight_decay) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_adam_param_remainder_cuda(
chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, lr, beta1,
beta2, epsilon, step, mode, bias_correction, weight_decay, device_id,
at::cuda::getCurrentCUDAStream());
}
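Editor's note: the param_remainder variant exists because, for BF16 parameters, a full FP32 master copy is redundant: BF16 already carries the top 16 bits of the FP32 word, so only the low 16 bits need a side tensor (this is what the newly added int16 dtype is used for). The sketch below shows that bit-level split and reconstruction under the assumption that the remainder is exactly the low half of the FP32 bit pattern; the actual kernel's rounding conventions may differ.

// Hedged sketch of the fp32-param-remainder idea: the bf16 parameter is assumed
// to hold the high 16 bits of each fp32 master weight and the int16 tensor the
// low 16 bits, so splitting and rejoining is pure bit surgery.
#include <cstdint>
#include <cstdio>
#include <cstring>

void split_fp32(float master, uint16_t& bf16_bits, int16_t& remainder) {
  uint32_t bits;
  std::memcpy(&bits, &master, sizeof(bits));
  bf16_bits = static_cast<uint16_t>(bits >> 16);      // what the bf16 param would hold
  remainder = static_cast<int16_t>(bits & 0xFFFFu);   // what the int16 tensor would hold
}

float join_fp32(uint16_t bf16_bits, int16_t remainder) {
  uint32_t bits = (static_cast<uint32_t>(bf16_bits) << 16) |
                  static_cast<uint16_t>(remainder);
  float master;
  std::memcpy(&master, &bits, sizeof(master));
  return master;
}

int main() {
  float master = 0.1234567f;
  uint16_t hi; int16_t lo;
  split_fp32(master, hi, lo);
  std::printf("roundtrip exact: %d\n", join_fp32(hi, lo) == master);  // prints 1
  return 0;
}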
void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction,
const float weight_decay, DType fp8_dtype) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_adam_fp8_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(),
num_lists, num_tensors, lr, beta1, beta2, epsilon, step, mode,
bias_correction, weight_decay, static_cast<NVTEDType>(fp8_dtype),
device_id, at::cuda::getCurrentCUDAStream());
}
void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor lr, const float beta1, const float beta2,
const float epsilon, at::Tensor step, const int mode,
const int bias_correction, const float weight_decay,
at::Tensor inv_scale) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
auto lr_cu = makeTransformerEngineTensor(lr);
auto step_cu = makeTransformerEngineTensor(step);
auto inv_scale_cu = makeTransformerEngineTensor(inv_scale);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_adam_capturable_cuda(
chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors,
lr_cu.data(), beta1, beta2, epsilon, step_cu.data(), mode, bias_correction, weight_decay,
inv_scale_cu.data(), device_id, at::cuda::getCurrentCUDAStream());
}
void multi_tensor_adam_capturable_master_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor lr, const float beta1, const float beta2,
const float epsilon, at::Tensor step, const int mode,
const int bias_correction, const float weight_decay,
at::Tensor inv_scale) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
auto lr_cu = makeTransformerEngineTensor(lr);
auto step_cu = makeTransformerEngineTensor(step);
auto inv_scale_cu = makeTransformerEngineTensor(inv_scale);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_adam_capturable_master_cuda(
chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors,
lr_cu.data(), beta1, beta2, epsilon, step_cu.data(), mode, bias_correction, weight_decay,
inv_scale_cu.data(), device_id, at::cuda::getCurrentCUDAStream());
}
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
void multi_tensor_compute_scale_and_scale_inv_cuda(
int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
float max_fp8, bool force_pow_2_scales, float epsilon) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_compute_scale_and_scale_inv_cuda(
chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, max_fp8,
force_pow_2_scales, epsilon, device_id, at::cuda::getCurrentCUDAStream());
}
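Editor's note: the scale-computation kernel wrapped here takes per-tensor amax values and produces matching scale / scale_inv pairs. Conceptually, the scale maps the observed amax onto the representable FP8 range, amax is floored by epsilon, and force_pow_2_scales rounds the scale to a power of two. The scalar sketch below is my reading of that contract; exact corner-case handling (zero or non-finite amax) may differ from the CUDA kernel.

// Hedged scalar reference for compute_scale_and_scale_inv.
#include <cmath>
#include <cstdio>

void compute_scale_reference(float amax, float max_fp8, bool force_pow_2_scales,
                             float epsilon, float& scale, float& scale_inv) {
  amax = std::fmax(amax, epsilon);
  scale = (amax > 0.f && std::isfinite(amax)) ? max_fp8 / amax : 1.f;
  if (force_pow_2_scales) {
    scale = std::exp2f(std::floor(std::log2f(scale)));  // round down to a power of two
  }
  scale_inv = 1.f / scale;
}

int main() {
  float scale, scale_inv;
  compute_scale_reference(/*amax=*/3.0f, /*max_fp8=*/448.f, /*force_pow_2_scales=*/true,
                          /*epsilon=*/0.f, scale, scale_inv);
  std::printf("scale=%f scale_inv=%f\n", scale, scale_inv);  // 128 and 0.0078125
  return 0;
}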
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
at::optional<bool> per_tensor_python) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;
auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
auto output = at::zeros({320}, float_options);
at::Tensor output_per_tensor;
at::Tensor ret_per_tensor;
auto ret = at::empty({1}, output.options());
int ntensors = tensor_lists[0].size();
int max_chunks_per_tensor = -1;
if (per_tensor) {
for (int t = 0; t < ntensors; t++) {
int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
if (max_chunks_this_tensor > max_chunks_per_tensor)
max_chunks_per_tensor = max_chunks_this_tensor;
}
output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor}, float_options);
ret_per_tensor = at::empty({ntensors}, float_options);
} else {
output_per_tensor = at::empty({0}, float_options);
ret_per_tensor = at::empty({0}, float_options);
}
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
auto output_cu = makeTransformerEngineTensor(output);
auto output_per_tensor_cu = makeTransformerEngineTensor(output_per_tensor);
auto ret_cu = makeTransformerEngineTensor(ret);
auto ret_per_tensor_cu = makeTransformerEngineTensor(ret_per_tensor);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_l2norm_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists,
num_tensors, output_cu.data(), output_per_tensor_cu.data(),
ret_cu.data(), ret_per_tensor_cu.data(), per_tensor,
max_chunks_per_tensor, device_id, at::cuda::getCurrentCUDAStream());
return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
}
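Editor's note: the buffer shapes above encode the reduction layout. `output` is a fixed-size scratch buffer of block-level partial sums for the global norm (its size of 320 appears to match the kernel's maximum block count per pass), while `output_per_tensor` holds one partial sum per (tensor, chunk) pair, padded to max_chunks_per_tensor slots per tensor so the final per-tensor reduction can use a fixed stride. The CPU sketch below shows that second reduction under the layout implied by the sizes computed here.

// CPU sketch of the per-tensor reduction implied by the buffer shapes:
// output_per_tensor is a [ntensors, max_chunks_per_tensor] row-major array of
// partial sums of squares; the per-tensor L2 norm is the sqrt of each row sum.
#include <cmath>
#include <cstdio>
#include <vector>

std::vector<float> reduce_per_tensor(const std::vector<float>& output_per_tensor,
                                     int ntensors, int max_chunks_per_tensor) {
  std::vector<float> norms(ntensors, 0.f);
  for (int t = 0; t < ntensors; ++t) {
    double acc = 0.0;
    for (int c = 0; c < max_chunks_per_tensor; ++c)
      acc += output_per_tensor[t * max_chunks_per_tensor + c];  // unused slots stay 0
    norms[t] = static_cast<float>(std::sqrt(acc));
  }
  return norms;
}

int main() {
  // Two tensors, up to 3 chunks each; the second tensor only produced 1 chunk.
  std::vector<float> partials = {1.f, 4.f, 4.f,    // tensor 0: sum of squares = 9
                                 16.f, 0.f, 0.f};  // tensor 1: sum of squares = 16
  auto norms = reduce_per_tensor(partials, 2, 3);
  std::printf("%f %f\n", norms[0], norms[1]);  // 3.0 4.0
  return 0;
}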
std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(
int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor inv_scale, at::optional<bool> per_tensor_python) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;
auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
auto output = at::zeros({320}, float_options);
at::Tensor output_per_tensor;
at::Tensor ret_per_tensor;
int ntensors = tensor_lists[0].size();
int max_chunks_per_tensor = -1;
// Create output tensors for the multi-tensor unscale L2 norm kernel.
if (per_tensor) {
for (int t = 0; t < ntensors; t++) {
int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
if (max_chunks_this_tensor > max_chunks_per_tensor)
max_chunks_per_tensor = max_chunks_this_tensor;
}
output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor}, float_options);
ret_per_tensor = at::empty({ntensors}, float_options);
} else {
output_per_tensor = at::empty({0}, float_options);
ret_per_tensor = at::empty({0}, float_options);
}
auto ret = at::empty({1}, output.options());
auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
makeTransformerEngineTensorList(tensor_lists);
auto output_cu = makeTransformerEngineTensor(output);
auto output_per_tensor_cu = makeTransformerEngineTensor(output_per_tensor);
auto ret_cu = makeTransformerEngineTensor(ret);
auto ret_per_tensor_cu = makeTransformerEngineTensor(ret_per_tensor);
auto inv_scale_cu = makeTransformerEngineTensor(inv_scale);
int device_id = tensor_lists[0][0].device().index();
nvte_multi_tensor_unscale_l2norm_cuda(
chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors,
output_cu.data(), output_per_tensor_cu.data(), ret_cu.data(), ret_per_tensor_cu.data(),
inv_scale_cu.data(), per_tensor, max_chunks_per_tensor, device_id,
at::cuda::getCurrentCUDAStream());
return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
}