Unverified Commit 4292653c authored by Przemyslaw Tredak, committed by GitHub

Avoid memory allocations and deallocations when creating NVTETensor (#1813)



* Changed the Tensor allocation strategy
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fixes
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Disable debug flag
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix the double free error
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fixed PyTorch recipe extension
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Hide TensorAllocator and fix the usage in LayerNorm
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Cleaning
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix permutation
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

---------
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 41909dc8
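
The change replaces per-tensor heap allocation (new/delete on transformer_engine::Tensor) with a process-wide pool: an NVTETensor handle becomes a 1-based index into a pre-reserved vector of Tensor objects, recycled through a free list. Below is a minimal sketch of that idea with simplified, hypothetical names; the actual TensorAllocator appears later in this diff.

// Minimal sketch of the pooling idea, assuming a simplified Tensor payload.
// Handles are 1-based indices cast to pointers, so a zero-initialized handle
// is naturally invalid. The real TensorAllocator reserves capacity up front so
// Tensor addresses stay stable and no per-tensor heap allocation occurs.
#include <cstdint>
#include <mutex>
#include <vector>

struct Tensor {
  int scaling_mode = 0;
  void clear() { scaling_mode = 0; }
};

using Handle = void *;  // stands in for NVTETensor

class HandlePool {
 public:
  Handle Allocate(int mode) {
    std::lock_guard<std::mutex> lock(mutex_);
    std::uintptr_t index;
    if (!free_list_.empty()) {
      index = free_list_.back();  // reuse a previously freed slot
      free_list_.pop_back();
    } else {
      storage_.emplace_back();
      index = storage_.size();  // 1-based index of the new slot
    }
    storage_[index - 1].scaling_mode = mode;
    return reinterpret_cast<Handle>(index);
  }

  void Free(Handle h) {
    std::lock_guard<std::mutex> lock(mutex_);
    const auto index = reinterpret_cast<std::uintptr_t>(h);
    if (index == 0) return;       // freeing a null handle is a no-op
    storage_[index - 1].clear();  // reset the payload for reuse
    free_list_.push_back(index);
  }

  Tensor *Lookup(Handle h) {
    const auto index = reinterpret_cast<std::uintptr_t>(h);
    if (index == 0 || index > storage_.size()) return nullptr;
    return &storage_[index - 1];
  }

 private:
  std::mutex mutex_;
  std::vector<Tensor> storage_;            // reserved up front in the real code
  std::vector<std::uintptr_t> free_list_;  // indices of freed slots
};
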
......@@ -90,11 +90,12 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
// Compute FP8 transpose if required
if (z->has_columnwise_data() && is_tensor_scaling(z->scaling_mode)) {
Tensor transpose_data;
transpose_data.data = z->columnwise_data;
transpose_data.scaling_mode = z->scaling_mode;
nvte_transpose(reinterpret_cast<NVTETensor>(z), reinterpret_cast<NVTETensor>(&transpose_data),
stream);
NVTETensor transpose_data = nvte_create_tensor(z->scaling_mode);
auto *t = convertNVTETensor(transpose_data);
t->data = z->columnwise_data;
nvte_transpose(static_cast<NVTETensor>(*z), transpose_data, stream);
nvte_destroy_tensor(transpose_data);
}
return;
......@@ -171,10 +172,9 @@ void nvte_rmsnorm_fwd(const NVTETensor x, // Nxhidden_size
cudaStream_t stream) {
NVTE_API_CALL(nvte_rmsnorm_fwd);
using namespace transformer_engine;
rmsnorm_fwd(*reinterpret_cast<const Tensor *>(x), *reinterpret_cast<const Tensor *>(gamma),
epsilon, reinterpret_cast<Tensor *>(z), reinterpret_cast<Tensor *>(rsigma),
reinterpret_cast<Tensor *>(workspace), multiprocessorCount, zero_centered_gamma,
stream);
rmsnorm_fwd(*convertNVTETensorCheck(x), *convertNVTETensorCheck(gamma), epsilon,
convertNVTETensor(z), convertNVTETensor(rsigma), convertNVTETensor(workspace),
multiprocessorCount, zero_centered_gamma, stream);
}
void nvte_rmsnorm_bwd(const NVTETensor dz, // Nxhidden_size
......@@ -186,9 +186,8 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, // Nxhidden_size
cudaStream_t stream) {
NVTE_API_CALL(nvte_rmsnorm_bwd);
using namespace transformer_engine;
rmsnorm_bwd(*reinterpret_cast<const Tensor *>(dz), *reinterpret_cast<const Tensor *>(x),
*reinterpret_cast<const Tensor *>(rsigma), *reinterpret_cast<const Tensor *>(gamma),
reinterpret_cast<Tensor *>(dx), reinterpret_cast<Tensor *>(dgamma),
reinterpret_cast<Tensor *>(workspace), multiprocessorCount, zero_centered_gamma,
stream);
rmsnorm_bwd(*convertNVTETensorCheck(dz), *convertNVTETensorCheck(x),
*convertNVTETensorCheck(rsigma), *convertNVTETensorCheck(gamma),
convertNVTETensor(dx), convertNVTETensor(dgamma), convertNVTETensor(workspace),
multiprocessorCount, zero_centered_gamma, stream);
}
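
The rmsnorm hunks above show the calling convention used throughout this PR: required inputs are converted with convertNVTETensorCheck, which fails via NVTE_CHECK on an invalid handle, while outputs and workspaces go through convertNVTETensor, which returns nullptr instead of throwing. A hedged, illustrative wrapper (not part of the PR):

void example_wrapper(const NVTETensor x, NVTETensor workspace) {
  using namespace transformer_engine;
  const Tensor &input = *convertNVTETensorCheck(x);  // must be a valid handle
  Tensor *ws = convertNVTETensor(workspace);         // may be a null/invalid handle
  (void)input;
  (void)ws;
}
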
......@@ -318,22 +318,16 @@ void nvte_permute(const NVTETensor input, NVTETensor output, const NVTETensor so
NVTETensor row_id_map, const NVTETensor prob, NVTETensor prob_grad,
const NVTETensor input_fwd, const int num_rows, const int topK,
const int num_cols, const int num_out_tokens, cudaStream_t stream) {
using namespace transformer_engine;
NVTE_API_CALL(nvte_permute);
const transformer_engine::Tensor *input_cu =
reinterpret_cast<const transformer_engine::Tensor *>(input);
const transformer_engine::Tensor *output_cu =
reinterpret_cast<const transformer_engine::Tensor *>(output);
const transformer_engine::Tensor *sorted_row_id_cu =
reinterpret_cast<const transformer_engine::Tensor *>(sorted_row_id);
const transformer_engine::Tensor *row_id_map_cu =
reinterpret_cast<const transformer_engine::Tensor *>(row_id_map);
const transformer_engine::Tensor *prob_cu =
reinterpret_cast<const transformer_engine::Tensor *>(prob);
const transformer_engine::Tensor *prob_grad_cu =
reinterpret_cast<const transformer_engine::Tensor *>(prob_grad);
const transformer_engine::Tensor *input_fwd_cu =
reinterpret_cast<const transformer_engine::Tensor *>(input_fwd);
const Tensor *input_cu = convertNVTETensorCheck(input);
const Tensor *output_cu = convertNVTETensorCheck(output);
const Tensor *sorted_row_id_cu = convertNVTETensorCheck(sorted_row_id);
const Tensor *row_id_map_cu = convertNVTETensorCheck(row_id_map);
const Tensor *prob_cu = convertNVTETensorCheck(prob);
const Tensor *prob_grad_cu = convertNVTETensorCheck(prob_grad);
const Tensor *input_fwd_cu = convertNVTETensorCheck(input_fwd);
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
input_cu->data.dtype, T,
......@@ -350,16 +344,13 @@ void nvte_permute(const NVTETensor input, NVTETensor output, const NVTETensor so
void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id_map,
const NVTETensor prob, const int num_rows, const int topK, const int num_cols,
cudaStream_t stream) {
using namespace transformer_engine;
NVTE_API_CALL(nvte_unpermute);
const transformer_engine::Tensor *input_cu =
reinterpret_cast<const transformer_engine::Tensor *>(input);
const transformer_engine::Tensor *output_cu =
reinterpret_cast<const transformer_engine::Tensor *>(output);
const transformer_engine::Tensor *row_id_map_cu =
reinterpret_cast<const transformer_engine::Tensor *>(row_id_map);
const transformer_engine::Tensor *prob_cu =
reinterpret_cast<const transformer_engine::Tensor *>(prob);
const Tensor *input_cu = convertNVTETensorCheck(input);
const Tensor *output_cu = convertNVTETensorCheck(output);
const Tensor *row_id_map_cu = convertNVTETensorCheck(row_id_map);
const Tensor *prob_cu = convertNVTETensorCheck(prob);
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
input_cu->data.dtype, T,
......
......@@ -108,7 +108,7 @@ void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaSt
// Check input tensor
NVTE_CHECK(input_ != nullptr, "Invalid input tensor (got NULL)");
const auto &input = *reinterpret_cast<const Tensor *>(input_);
const auto &input = *convertNVTETensorCheck(input_);
NVTE_CHECK(input.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Input tensor for amax computation must unquantized, "
"but got scaling_mode=",
......@@ -121,7 +121,7 @@ void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaSt
// Check output tensor
NVTE_CHECK(output_ != nullptr, "Invalid output tensor (got NULL)");
auto &output = *reinterpret_cast<Tensor *>(output_);
auto &output = *convertNVTETensorCheck(output_);
NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Output tensor for amax computation must be FP8 tensor with per-tensor scaling, "
"but got scaling_mode=",
......@@ -166,7 +166,7 @@ void nvte_compute_scale_from_amax(NVTETensor output_, const NVTEQuantizationConf
// Check output tensor
NVTE_CHECK(output_ != nullptr, "Invalid output tensor (got NULL)");
auto &output = *reinterpret_cast<Tensor *>(output_);
auto &output = *convertNVTETensorCheck(output_);
NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Tensor must be FP8 tensor with per-tensor scaling, "
"but got scaling_mode=",
......
......@@ -397,9 +397,9 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update(
NVTE_API_CALL(nvte_delayed_scaling_recipe_amax_and_scale_update);
using namespace transformer_engine;
delayed_scaling_recipe::amax_and_scale_update(
*reinterpret_cast<const Tensor*>(amax_history), *reinterpret_cast<const Tensor*>(scale),
reinterpret_cast<Tensor*>(updated_amax_history), reinterpret_cast<Tensor*>(updated_scale),
amax_compute_algo, static_cast<DType>(fp8_dtype), margin, stream);
*convertNVTETensorCheck(amax_history), *convertNVTETensorCheck(scale),
convertNVTETensor(updated_amax_history), convertNVTETensor(updated_scale), amax_compute_algo,
static_cast<DType>(fp8_dtype), margin, stream);
}
void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
......@@ -411,10 +411,10 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
size_t num_tensors = amax_histories.size();
std::vector<Tensor*> t_amax_histories, t_scales;
for (size_t i = 0; i < num_tensors; i++) {
t_amax_histories.push_back(reinterpret_cast<Tensor*>(amax_histories[i]));
t_scales.push_back(reinterpret_cast<Tensor*>(scales[i]));
t_amax_histories.push_back(convertNVTETensor(amax_histories[i]));
t_scales.push_back(convertNVTETensor(scales[i]));
}
delayed_scaling_recipe::amax_and_scale_update_after_reduction(
*reinterpret_cast<const Tensor*>(amax_reduction_buffer), t_amax_histories, t_scales,
amax_compute_algo, static_cast<DType>(fp8_dtype), margin, stream);
*convertNVTETensorCheck(amax_reduction_buffer), t_amax_histories, t_scales, amax_compute_algo,
static_cast<DType>(fp8_dtype), margin, stream);
}
......@@ -227,8 +227,8 @@ void nvte_fp8_block_scaling_compute_partial_amax(const NVTETensor inp, NVTETenso
NVTE_API_CALL(nvte_fp8_block_scaling_compute_partial_amax);
using namespace transformer_engine;
fp8_block_scaling_recipe::fp8_block_scaling_compute_partial_amax(
*reinterpret_cast<const Tensor *>(inp), *reinterpret_cast<Tensor *>(amax), h, w,
amax_stride_h, amax_stride_w, start_offset, block_len, stream);
*convertNVTETensorCheck(inp), *convertNVTETensorCheck(amax), h, w, amax_stride_h,
amax_stride_w, start_offset, block_len, stream);
}
void nvte_fp8_block_scaling_partial_cast(const NVTETensor inp, NVTETensor out,
......@@ -239,7 +239,7 @@ void nvte_fp8_block_scaling_partial_cast(const NVTETensor inp, NVTETensor out,
NVTE_API_CALL(nvte_fp8_block_scaling_partial_cast);
using namespace transformer_engine;
fp8_block_scaling_recipe::fp8_block_scaling_partial_cast(
*reinterpret_cast<const Tensor *>(inp), *reinterpret_cast<Tensor *>(out),
*reinterpret_cast<const Tensor *>(scale), h, w, scale_stride_h, scale_stride_w, start_offset,
block_len, static_cast<DType>(out_dtype), stream);
*convertNVTETensorCheck(inp), *convertNVTETensorCheck(out), *convertNVTETensorCheck(scale), h,
w, scale_stride_h, scale_stride_w, start_offset, block_len, static_cast<DType>(out_dtype),
stream);
}
......@@ -333,6 +333,5 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_swizzle_scaling_factors);
using namespace transformer_engine;
swizzle_scaling_factors(reinterpret_cast<const Tensor*>(input), reinterpret_cast<Tensor*>(output),
stream);
swizzle_scaling_factors(convertNVTETensorCheck(input), convertNVTETensorCheck(output), stream);
}
......@@ -6,11 +6,15 @@
#include <transformer_engine/transformer_engine.h>
#include <atomic>
#include <climits>
#include <cstring>
#include <iostream>
#include <mutex>
#include "common.h"
#include "common/util/cuda_runtime.h"
#include "common/util/logging.h"
namespace transformer_engine {
......@@ -192,24 +196,139 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
CheckScaleTensorShape(t, name);
}
class TensorAllocator {
public:
static TensorAllocator &instance() {
static TensorAllocator allocator;
return allocator;
}
~TensorAllocator() {}
NVTETensor Allocate(NVTEScalingMode mode) {
std::lock_guard<std::mutex> lock(mutex);
if (!free_list.empty()) {
uintptr_t index = free_list.back();
NVTETensor ret = reinterpret_cast<NVTETensor>(index);
free_list.pop_back();
if (debug) {
std::cout << "Allocated " << index
<< " from free list. Free list size: " << free_list.size() << " and capacity "
<< free_list.capacity() << std::endl;
}
// 1-based indexing
memory[index - 1].scaling_mode = mode;
return ret;
}
if (memory.size() < memory.capacity()) {
memory.emplace_back();
Tensor &t = memory.back();
size = memory.size();
// 1-based indexing
uintptr_t index = memory.size();
if (debug) {
std::cout << "Allocated " << index << ". Memory size: " << memory.size() << " and capacity "
<< memory.capacity() << std::endl;
}
t.scaling_mode = mode;
t.nvte_tensor = reinterpret_cast<NVTETensor>(index);
return reinterpret_cast<NVTETensor>(index);
}
NVTE_ERROR("Cannot allocate a new NVTETensor. Maximum number of tensors reached: ",
MAX_TENSOR_NUM, ". There is probably a memory leak in your application.");
}
void Free(NVTETensor t) {
std::lock_guard<std::mutex> lock(mutex);
uintptr_t index = reinterpret_cast<uintptr_t>(t);
if (index == 0) return;
NVTE_CHECK(index <= memory.size(), "Invalid tensor.");
free_list.push_back(index);
// Clean up
memory[index - 1].clear();
if (debug) {
std::cout << "Freed " << index << ". Free list size: " << free_list.size() << " and capacity "
<< free_list.capacity() << std::endl;
}
}
void Free(NVTETensor *t, size_t N) {
std::lock_guard<std::mutex> lock(mutex);
for (size_t i = 0; i < N; ++i) {
uintptr_t index = reinterpret_cast<uintptr_t>(t[i]);
if (index == 0) continue;
NVTE_CHECK(index <= memory.size(), "Invalid tensor.");
free_list.push_back(index);
// Clean up
memory[index - 1].clear();
}
if (debug) {
std::cout << "Freed range of" << N << " tensors. Free list size: " << free_list.size()
<< " and capacity " << free_list.capacity() << std::endl;
}
}
Tensor *convertNVTETensor(NVTETensor t) {
uintptr_t index = reinterpret_cast<uintptr_t>(t);
// 1-based indexing to enable 0-initialization of NVTETensor
// to be invalid tensor
static_assert(nullptr == 0);
if (index != 0 && index <= size) {
return &(memory[index - 1]);
}
return nullptr;
}
void setDebug(bool debug) {
std::lock_guard<std::mutex> lock(mutex);
this->debug = debug;
}
private:
TensorAllocator() {
std::lock_guard<std::mutex> lock(mutex);
memory.reserve(MAX_TENSOR_NUM);
}
std::mutex mutex;
std::atomic<size_t> size;
// Allocate at most 20 MB for tensors
// Should be replaced by virtual memory allocation
const size_t MAX_TENSOR_NUM = 20 * 1024 * 1024 / sizeof(Tensor);
std::vector<uintptr_t> free_list;
std::vector<Tensor> memory;
bool debug = false;
};
Tensor *convertNVTETensor(const NVTETensor t) {
return TensorAllocator::instance().convertNVTETensor(t);
}
Tensor *convertNVTETensorCheck(const NVTETensor t) {
Tensor *ptr = TensorAllocator::instance().convertNVTETensor(t);
NVTE_CHECK(ptr != nullptr, "Invalid tensor.");
return ptr;
}
} // namespace transformer_engine
NVTETensor nvte_create_tensor(NVTEScalingMode scaling_mode) {
transformer_engine::Tensor *ret = new transformer_engine::Tensor;
ret->scaling_mode = scaling_mode;
NVTETensor ret = transformer_engine::TensorAllocator::instance().Allocate(scaling_mode);
return ret;
}
void nvte_destroy_tensor(NVTETensor tensor) {
if (tensor == nullptr) return;
auto *t = reinterpret_cast<transformer_engine::Tensor *>(tensor);
delete t;
transformer_engine::TensorAllocator::instance().Free(tensor);
}
void nvte_destroy_tensors(NVTETensor *tensors, size_t N) {
transformer_engine::TensorAllocator::instance().Free(tensors, N);
}
NVTEDType nvte_tensor_type(const NVTETensor tensor) {
if (tensor == nullptr) return kNVTEFloat32;
return static_cast<NVTEDType>(
reinterpret_cast<const transformer_engine::Tensor *>(tensor)->dtype());
auto *t = transformer_engine::convertNVTETensor(tensor);
if (t == nullptr) return kNVTEFloat32;
return static_cast<NVTEDType>(t->dtype());
}
NVTEShape nvte_make_shape(const size_t *data, size_t ndim) {
......@@ -227,23 +346,24 @@ NVTEShape nvte_make_shape(const size_t *data, size_t ndim) {
}
NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
if (tensor == nullptr) {
auto *t = transformer_engine::convertNVTETensor(tensor);
if (t == nullptr) {
NVTE_ERROR("Invalid tensor");
}
// Determine tensor shape depending on tensor format
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
std::vector<size_t> shape = t.shape();
const std::vector<size_t> &shape = t->shape();
return nvte_make_shape(shape.data(), shape.size());
}
NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) {
if (tensor == nullptr) {
auto *t = transformer_engine::convertNVTETensor(tensor);
if (t == nullptr) {
NVTE_ERROR("Invalid tensor");
}
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
return nvte_make_shape(t.columnwise_data.shape.data(), t.columnwise_data.shape.size());
const std::vector<size_t> &shape = t->columnwise_data.shape;
return nvte_make_shape(shape.data(), shape.size());
}
size_t nvte_tensor_ndims(const NVTETensor tensor) { return nvte_tensor_shape(tensor).ndim; }
......@@ -265,82 +385,82 @@ size_t nvte_tensor_numel(const NVTETensor tensor) {
}
size_t nvte_tensor_element_size(const NVTETensor tensor) {
if (tensor == nullptr) return sizeof(float);
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
return transformer_engine::typeToSize(t.dtype());
auto *t = transformer_engine::convertNVTETensor(tensor);
if (t == nullptr) return sizeof(float);
return transformer_engine::typeToSize(t->dtype());
}
void *nvte_tensor_data(const NVTETensor tensor) {
if (tensor == nullptr) return nullptr;
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
return t.data.dptr;
auto *t = transformer_engine::convertNVTETensor(tensor);
if (t == nullptr) return nullptr;
return t->data.dptr;
}
void *nvte_tensor_columnwise_data(const NVTETensor tensor) {
if (tensor == nullptr) return nullptr;
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
return t.columnwise_data.dptr;
auto *t = transformer_engine::convertNVTETensor(tensor);
if (t == nullptr) return nullptr;
return t->columnwise_data.dptr;
}
float *nvte_tensor_amax(const NVTETensor tensor) {
if (tensor == nullptr) return nullptr;
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
NVTE_CHECK(t.amax.dtype == transformer_engine::DType::kFloat32,
auto *t = transformer_engine::convertNVTETensor(tensor);
if (t == nullptr) return nullptr;
NVTE_CHECK(t->amax.dtype == transformer_engine::DType::kFloat32,
"Tensor's amax must have Float32 type!");
return reinterpret_cast<float *>(t.amax.dptr);
return reinterpret_cast<float *>(t->amax.dptr);
}
float *nvte_tensor_scale(const NVTETensor tensor) {
if (tensor == nullptr) return nullptr;
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
NVTE_CHECK(t.scale.dtype == transformer_engine::DType::kFloat32,
auto *t = transformer_engine::convertNVTETensor(tensor);
if (t == nullptr) return nullptr;
NVTE_CHECK(t->scale.dtype == transformer_engine::DType::kFloat32,
"Tensor's scale must have Float32 type!");
return reinterpret_cast<float *>(t.scale.dptr);
return reinterpret_cast<float *>(t->scale.dptr);
}
float *nvte_tensor_scale_inv(const NVTETensor tensor) {
if (tensor == nullptr) return nullptr;
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
return reinterpret_cast<float *>(t.scale_inv.dptr);
auto *t = transformer_engine::convertNVTETensor(tensor);
if (t == nullptr) return nullptr;
return reinterpret_cast<float *>(t->scale_inv.dptr);
}
void *nvte_tensor_columnwise_scale_inv(const NVTETensor tensor) {
if (tensor == nullptr) return nullptr;
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
return t.columnwise_scale_inv.dptr;
auto *t = transformer_engine::convertNVTETensor(tensor);
if (t == nullptr) return nullptr;
return t->columnwise_scale_inv.dptr;
}
NVTEShape nvte_tensor_scale_inv_shape(const NVTETensor tensor) {
if (tensor == nullptr) {
auto *t = transformer_engine::convertNVTETensor(tensor);
if (t == nullptr) {
return nvte_make_shape(nullptr, 0);
}
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
return nvte_make_shape(t.scale_inv.shape.data(), t.scale_inv.shape.size());
return nvte_make_shape(t->scale_inv.shape.data(), t->scale_inv.shape.size());
}
void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name,
const NVTEBasicTensor *param) {
NVTE_CHECK(tensor != nullptr, "Tensor pointer can't be NULL.");
NVTE_CHECK(*tensor != nullptr, "Tensor is not allocated.");
auto &t = *reinterpret_cast<transformer_engine::Tensor *>(*tensor);
auto *t = transformer_engine::convertNVTETensor(*tensor);
NVTE_CHECK(t != nullptr, "Tensor is not allocated.");
switch (param_name) {
case kNVTERowwiseData:
t.data = *param;
t->data = *param;
break;
case kNVTEColumnwiseData:
t.columnwise_data = *param;
t->columnwise_data = *param;
break;
case kNVTEScale:
t.scale = *param;
t->scale = *param;
break;
case kNVTEAmax:
t.amax = *param;
t->amax = *param;
break;
case kNVTERowwiseScaleInv:
t.scale_inv = *param;
t->scale_inv = *param;
break;
case kNVTEColumnwiseScaleInv:
t.columnwise_scale_inv = *param;
t->columnwise_scale_inv = *param;
break;
default:
NVTE_ERROR("Unknown tensor parameter!");
......@@ -351,7 +471,7 @@ NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam p
if (tensor == nullptr) {
return {nullptr, kNVTEFloat32, nvte_make_shape(nullptr, 0)};
}
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
const auto &t = *transformer_engine::convertNVTETensorCheck(tensor);
switch (param_name) {
case kNVTERowwiseData:
return t.data;
......@@ -371,25 +491,27 @@ NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam p
}
NVTEScalingMode nvte_tensor_scaling_mode(const NVTETensor tensor) {
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
if (tensor == nullptr) {
return NVTE_DELAYED_TENSOR_SCALING;
}
const auto &t = *transformer_engine::convertNVTETensorCheck(tensor);
return t.scaling_mode;
}
void nvte_tensor_pack_create(NVTETensorPack *pack) {
for (int i = 0; i < pack->MAX_SIZE; i++) {
pack->tensors[i] = reinterpret_cast<NVTETensor>(new transformer_engine::Tensor);
pack->tensors[i] =
transformer_engine::TensorAllocator::instance().Allocate(NVTE_DELAYED_TENSOR_SCALING);
}
}
void nvte_tensor_pack_destroy(NVTETensorPack *pack) {
for (int i = 0; i < pack->MAX_SIZE; i++) {
auto *t = reinterpret_cast<transformer_engine::Tensor *>(pack->tensors[i]);
delete t;
}
transformer_engine::TensorAllocator::instance().Free(pack->tensors, pack->MAX_SIZE);
}
void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream) {
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
if (tensor == nullptr) return;
const auto &t = *transformer_engine::convertNVTETensorCheck(tensor);
// Zero out tensor data if allocated
if (t.data.dptr != nullptr) {
size_t size_in_bytes = nvte_tensor_element_size(tensor) * nvte_tensor_numel(tensor);
......
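
Taken together, the hunks above define the public lifecycle for pooled tensors: create a handle, attach data via tensor params, and return the slot to the pool. A hedged usage sketch (the device buffer and shape are placeholders):

#include <transformer_engine/transformer_engine.h>

void example(void *device_buffer) {
  NVTETensor t = nvte_create_tensor(NVTE_DELAYED_TENSOR_SCALING);
  const size_t dims[2] = {128, 64};  // placeholder shape
  NVTEBasicTensor data = {device_buffer, kNVTEFloat32, nvte_make_shape(dims, 2)};
  nvte_set_tensor_param(&t, kNVTERowwiseData, &data);
  // ... pass t to NVTE kernels ...
  nvte_destroy_tensor(t);  // returns the slot to the allocator's free list
}
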
......@@ -348,15 +348,15 @@ void nvte_cast_transpose(const NVTETensor input, NVTETensor output, cudaStream_t
NVTE_API_CALL(nvte_cast_transpose);
using namespace transformer_engine;
auto noop = Tensor();
transformer_engine::detail::cast_transpose(*reinterpret_cast<const Tensor *>(input), noop,
reinterpret_cast<Tensor *>(output), stream);
transformer_engine::detail::cast_transpose(*convertNVTETensorCheck(input), noop,
convertNVTETensor(output), stream);
}
void nvte_cast_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose_with_noop);
using namespace transformer_engine;
transformer_engine::detail::cast_transpose(*reinterpret_cast<const Tensor *>(input),
*reinterpret_cast<const Tensor *>(noop),
reinterpret_cast<Tensor *>(output), stream);
transformer_engine::detail::cast_transpose(*convertNVTETensorCheck(input),
*convertNVTETensorCheck(noop),
convertNVTETensor(output), stream);
}
......@@ -17,6 +17,7 @@
#include "../util/string.h"
#include "../utils.cuh"
#include "cast_transpose.h"
#include "common/common.h"
namespace transformer_engine {
......@@ -1267,9 +1268,8 @@ void nvte_cast_transpose_dbias(const NVTETensor input, NVTETensor output, NVTETe
constexpr const NVTETensor activation_input = nullptr;
cast_transpose_fused<IS_DBIAS, IS_DACT, IS_ACT, ComputeType, Empty, nullptr>(
*reinterpret_cast<const Tensor *>(input), reinterpret_cast<const Tensor *>(activation_input),
reinterpret_cast<Tensor *>(output), reinterpret_cast<Tensor *>(dbias),
reinterpret_cast<Tensor *>(workspace), stream);
*convertNVTETensorCheck(input), convertNVTETensor(activation_input),
convertNVTETensor(output), convertNVTETensor(dbias), convertNVTETensor(workspace), stream);
}
void nvte_cast_transpose_dbias_dgelu(const NVTETensor input, const NVTETensor act_input,
......@@ -1284,9 +1284,9 @@ void nvte_cast_transpose_dbias_dgelu(const NVTETensor input, const NVTETensor ac
constexpr bool IS_ACT = false;
cast_transpose_fused<IS_DBIAS, IS_DACT, IS_ACT, ComputeType, Empty, dgelu<fp32, fp32>>(
*reinterpret_cast<const Tensor *>(input), reinterpret_cast<const Tensor *>(act_input),
reinterpret_cast<Tensor *>(output), reinterpret_cast<Tensor *>(dbias),
reinterpret_cast<Tensor *>(workspace), stream);
*convertNVTETensorCheck(input), convertNVTETensorCheck(act_input),
convertNVTETensorCheck(output), convertNVTETensorCheck(dbias), convertNVTETensor(workspace),
stream);
}
void nvte_cast_transpose_dbias_dsilu(const NVTETensor input, const NVTETensor silu_input,
......@@ -1301,9 +1301,9 @@ void nvte_cast_transpose_dbias_dsilu(const NVTETensor input, const NVTETensor si
constexpr bool IS_ACT = false;
cast_transpose_fused<IS_DBIAS, IS_DACT, IS_ACT, ComputeType, Empty, dsilu<fp32, fp32>>(
*reinterpret_cast<const Tensor *>(input), reinterpret_cast<const Tensor *>(silu_input),
reinterpret_cast<Tensor *>(output), reinterpret_cast<Tensor *>(dbias),
reinterpret_cast<Tensor *>(workspace), stream);
*convertNVTETensorCheck(input), convertNVTETensorCheck(silu_input),
convertNVTETensorCheck(output), convertNVTETensorCheck(dbias), convertNVTETensor(workspace),
stream);
}
void nvte_cast_transpose_dbias_drelu(const NVTETensor input, const NVTETensor relu_input,
......@@ -1318,9 +1318,9 @@ void nvte_cast_transpose_dbias_drelu(const NVTETensor input, const NVTETensor re
constexpr bool IS_ACT = false;
cast_transpose_fused<IS_DBIAS, IS_DACT, IS_ACT, ComputeType, Empty, drelu<fp32, fp32>>(
*reinterpret_cast<const Tensor *>(input), reinterpret_cast<const Tensor *>(relu_input),
reinterpret_cast<Tensor *>(output), reinterpret_cast<Tensor *>(dbias),
reinterpret_cast<Tensor *>(workspace), stream);
*convertNVTETensorCheck(input), convertNVTETensorCheck(relu_input),
convertNVTETensorCheck(output), convertNVTETensorCheck(dbias), convertNVTETensor(workspace),
stream);
}
void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input, const NVTETensor srelu_input,
......@@ -1335,9 +1335,9 @@ void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input, const NVTETensor s
constexpr bool IS_ACT = false;
cast_transpose_fused<IS_DBIAS, IS_DACT, IS_ACT, ComputeType, Empty, dsrelu<fp32, fp32>>(
*reinterpret_cast<const Tensor *>(input), reinterpret_cast<const Tensor *>(srelu_input),
reinterpret_cast<Tensor *>(output), reinterpret_cast<Tensor *>(dbias),
reinterpret_cast<Tensor *>(workspace), stream);
*convertNVTETensorCheck(input), convertNVTETensorCheck(srelu_input),
convertNVTETensorCheck(output), convertNVTETensorCheck(dbias), convertNVTETensor(workspace),
stream);
}
void nvte_cast_transpose_dbias_dqgelu(const NVTETensor input, const NVTETensor qgelu_input,
......@@ -1352,9 +1352,9 @@ void nvte_cast_transpose_dbias_dqgelu(const NVTETensor input, const NVTETensor q
constexpr bool IS_ACT = false;
cast_transpose_fused<IS_DBIAS, IS_DACT, IS_ACT, ComputeType, Empty, dqgelu<fp32, fp32>>(
*reinterpret_cast<const Tensor *>(input), reinterpret_cast<const Tensor *>(qgelu_input),
reinterpret_cast<Tensor *>(output), reinterpret_cast<Tensor *>(dbias),
reinterpret_cast<Tensor *>(workspace), stream);
*convertNVTETensorCheck(input), convertNVTETensorCheck(qgelu_input),
convertNVTETensorCheck(output), convertNVTETensorCheck(dbias), convertNVTETensor(workspace),
stream);
}
void nvte_dgeglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input,
......@@ -1364,8 +1364,8 @@ void nvte_dgeglu_cast_transpose(const NVTETensor input, const NVTETensor gated_a
using namespace transformer_engine::detail;
dgated_act_cast_transpose<ComputeType, Empty, dgelu<fp32, fp32>, gelu<fp32, fp32>>(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(gated_act_input),
reinterpret_cast<Tensor *>(output), stream);
*convertNVTETensorCheck(input), *convertNVTETensorCheck(gated_act_input),
convertNVTETensorCheck(output), stream);
}
void nvte_dswiglu_cast_transpose(const NVTETensor input, const NVTETensor swiglu_input,
......@@ -1375,8 +1375,8 @@ void nvte_dswiglu_cast_transpose(const NVTETensor input, const NVTETensor swiglu
using namespace transformer_engine::detail;
dgated_act_cast_transpose<ComputeType, Empty, dsilu<fp32, fp32>, silu<fp32, fp32>>(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(swiglu_input),
reinterpret_cast<Tensor *>(output), stream);
*convertNVTETensorCheck(input), *convertNVTETensorCheck(swiglu_input),
convertNVTETensorCheck(output), stream);
}
void nvte_dreglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input,
......@@ -1386,8 +1386,8 @@ void nvte_dreglu_cast_transpose(const NVTETensor input, const NVTETensor gated_a
using namespace transformer_engine::detail;
dgated_act_cast_transpose<ComputeType, Empty, drelu<fp32, fp32>, relu<fp32, fp32>>(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(gated_act_input),
reinterpret_cast<Tensor *>(output), stream);
*convertNVTETensorCheck(input), *convertNVTETensorCheck(gated_act_input),
convertNVTETensorCheck(output), stream);
}
void nvte_dsreglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input,
......@@ -1397,8 +1397,8 @@ void nvte_dsreglu_cast_transpose(const NVTETensor input, const NVTETensor gated_
using namespace transformer_engine::detail;
dgated_act_cast_transpose<ComputeType, Empty, dsrelu<fp32, fp32>, srelu<fp32, fp32>>(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(gated_act_input),
reinterpret_cast<Tensor *>(output), stream);
*convertNVTETensorCheck(input), *convertNVTETensorCheck(gated_act_input),
convertNVTETensorCheck(output), stream);
}
void nvte_dqgeglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input,
......@@ -1408,6 +1408,6 @@ void nvte_dqgeglu_cast_transpose(const NVTETensor input, const NVTETensor gated_
using namespace transformer_engine::detail;
dgated_act_cast_transpose<ComputeType, Empty, dqgelu<fp32, fp32>, qgelu<fp32, fp32>>(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(gated_act_input),
reinterpret_cast<Tensor *>(output), stream);
*convertNVTETensorCheck(input), *convertNVTETensorCheck(gated_act_input),
convertNVTETensorCheck(output), stream);
}
......@@ -334,8 +334,8 @@ void nvte_multi_cast_transpose(size_t num_tensors, const NVTETensor* input_list,
using namespace transformer_engine;
std::vector<Tensor*> input_list_, output_list_;
for (size_t i = 0; i < num_tensors; ++i) {
input_list_.push_back(reinterpret_cast<Tensor*>(const_cast<NVTETensor&>(input_list[i])));
output_list_.push_back(reinterpret_cast<Tensor*>(output_list[i]));
input_list_.push_back(convertNVTETensorCheck(input_list[i]));
output_list_.push_back(convertNVTETensorCheck(output_list[i]));
}
multi_cast_transpose(input_list_, output_list_, stream);
}
......@@ -288,14 +288,13 @@ void nvte_transpose(const NVTETensor input, NVTETensor output, cudaStream_t stre
NVTE_API_CALL(nvte_transpose);
using namespace transformer_engine;
auto noop = Tensor();
transpose(*reinterpret_cast<const Tensor *>(input), noop, reinterpret_cast<Tensor *>(output),
stream);
transpose(*convertNVTETensorCheck(input), noop, convertNVTETensor(output), stream);
}
void nvte_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_transpose_with_noop);
using namespace transformer_engine;
transpose(*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(noop),
reinterpret_cast<Tensor *>(output), stream);
transpose(*convertNVTETensorCheck(input), *convertNVTETensorCheck(noop),
convertNVTETensor(output), stream);
}
......@@ -495,7 +495,6 @@ void nvte_fp8_transpose_dbias(const NVTETensor input, NVTETensor transposed_outp
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_fp8_transpose_dbias);
using namespace transformer_engine;
fp8_transpose_dbias(
*reinterpret_cast<const Tensor *>(input), reinterpret_cast<Tensor *>(transposed_output),
reinterpret_cast<Tensor *>(dbias), reinterpret_cast<Tensor *>(workspace), stream);
fp8_transpose_dbias(*convertNVTETensorCheck(input), convertNVTETensor(transposed_output),
convertNVTETensor(dbias), convertNVTETensor(workspace), stream);
}
......@@ -154,6 +154,5 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activat
void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_dequantize);
using namespace transformer_engine;
detail::dequantize_helper(*reinterpret_cast<const Tensor *>(input),
reinterpret_cast<Tensor *>(output), stream);
detail::dequantize_helper(*convertNVTETensorCheck(input), convertNVTETensorCheck(output), stream);
}
......@@ -1057,10 +1057,9 @@ void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input,
cudaStream_t stream) {
using namespace gated_kernels;
Tensor grad_empty_tensor;
const Tensor &grad_tensor =
IS_DGATED ? *(reinterpret_cast<const Tensor *>(grad)) : grad_empty_tensor;
const Tensor gated_input_tensor = *reinterpret_cast<const Tensor *>(gated_input);
Tensor *output_tensor = reinterpret_cast<Tensor *>(output);
const Tensor &grad_tensor = IS_DGATED ? *(convertNVTETensorCheck(grad)) : grad_empty_tensor;
const Tensor gated_input_tensor = *convertNVTETensorCheck(gated_input);
Tensor *output_tensor = convertNVTETensorCheck(output);
if (is_supported_by_CC_100()) {
quantize_gated<IS_DGATED, ParamOP, ActOP, DActOP>(grad_tensor, gated_input_tensor,
......
......@@ -1222,23 +1222,23 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
const Tensor *activation_input_tensor;
if constexpr (IS_DBIAS || IS_DACT) {
// backward - input is incoming gradient
input_tensor = reinterpret_cast<const Tensor *>(grad);
activation_input_tensor = reinterpret_cast<const Tensor *>(input);
input_tensor = convertNVTETensorCheck(grad);
activation_input_tensor = convertNVTETensor(input);
} else {
// forward - input is activation input
input_tensor = reinterpret_cast<const Tensor *>(input);
input_tensor = convertNVTETensorCheck(input);
activation_input_tensor = nullptr;
}
auto output_tensor = reinterpret_cast<Tensor *>(output);
auto dbias_tensor = reinterpret_cast<Tensor *>(dbias);
auto workspace_tensor = reinterpret_cast<Tensor *>(workspace);
auto output_tensor = convertNVTETensorCheck(output);
auto dbias_tensor = convertNVTETensor(dbias);
auto workspace_tensor = convertNVTETensor(workspace);
const QuantizationConfig *quant_config_cpp =
reinterpret_cast<const QuantizationConfig *>(quant_config);
// extract noop tensor from quant_config_cpp if it's not null
const NVTETensor noop = quant_config_cpp ? quant_config_cpp->noop_tensor : nullptr;
const auto noop_tensor = noop != nullptr ? *(reinterpret_cast<const Tensor *>(noop)) : Tensor();
const auto noop_tensor = noop != nullptr ? *(convertNVTETensorCheck(noop)) : Tensor();
switch (output_tensor->scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: {
......
......@@ -211,8 +211,8 @@ void nvte_multi_padding(size_t num_tensors, const NVTETensor* input_list, NVTETe
std::vector<Tensor*> input_list_, output_list_;
std::vector<int> padded_num_rows_list_;
for (size_t i = 0; i < num_tensors; ++i) {
input_list_.push_back(reinterpret_cast<Tensor*>(const_cast<NVTETensor&>(input_list[i])));
output_list_.push_back(reinterpret_cast<Tensor*>(output_list[i]));
input_list_.push_back(convertNVTETensorCheck(input_list[i]));
output_list_.push_back(convertNVTETensorCheck(output_list[i]));
padded_num_rows_list_.push_back(padded_num_rows_list[i]);
}
multi_padding(input_list_, output_list_, padded_num_rows_list_, stream);
......
......@@ -6,6 +6,7 @@
#include "extensions.h"
#include "transformer_engine/fused_attn.h"
#include "transformer_engine/transformer_engine.h"
namespace transformer_engine {
namespace jax {
......@@ -40,33 +41,46 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t
// all backends need softmax but expect different shapes/dtypes
// start with the max512 sequence length softmax shape/dtype and correct later
tensor_pack->size = 1;
Tensor *softmax_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[0]);
softmax_aux->data.dptr = softmax_buf;
softmax_aux->data.shape =
std::vector<size_t>{input_batch, attn_heads, q_max_seqlen, kv_max_seqlen};
softmax_aux->data.dtype = dtype;
NVTETensor &softmax_aux = tensor_pack->tensors[0];
NVTEBasicTensor softmax_aux_data;
softmax_aux_data.data_ptr = softmax_buf;
softmax_aux_data.shape.ndim = 4;
softmax_aux_data.shape.data[0] = input_batch;
softmax_aux_data.shape.data[1] = attn_heads;
softmax_aux_data.shape.data[2] = q_max_seqlen;
softmax_aux_data.shape.data[3] = kv_max_seqlen;
softmax_aux_data.dtype = static_cast<NVTEDType>(dtype);
// arbitrary sequence length backend needs the RNG state and a different shape/dtype softmax
if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
tensor_pack->size = 2;
Tensor *rng_state_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[1]);
rng_state_aux->data.dptr = rng_state_buf;
rng_state_aux->data.shape = std::vector<size_t>{2};
rng_state_aux->data.dtype = DType::kInt64;
NVTETensor &rng_state_aux = tensor_pack->tensors[1];
NVTEBasicTensor rng_state_aux_data;
rng_state_aux_data.data_ptr = rng_state_buf;
rng_state_aux_data.shape = {};
rng_state_aux_data.shape.ndim = 2;
rng_state_aux_data.dtype = static_cast<NVTEDType>(DType::kInt64);
nvte_set_tensor_param(&rng_state_aux, kNVTERowwiseData, &rng_state_aux_data);
// correct softmax shape/dtype
softmax_aux->data.shape.at(3) = 1; // {B,H,Qs,Ks} -> {B,H,Qs,1}
softmax_aux->data.dtype = DType::kFloat32;
softmax_aux_data.shape.data[3] = 1; // {B,H,Qs,Ks} -> {B,H,Qs,1}
softmax_aux_data.dtype = static_cast<NVTEDType>(DType::kFloat32);
// include bias if enabled
if (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS && bias_type != NVTE_Bias_Type::NVTE_ALIBI) {
tensor_pack->size = 3;
Tensor *bias_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[2]);
bias_aux->data.dptr = bias_buf;
bias_aux->data.shape =
std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen};
bias_aux->data.dtype = dtype;
NVTETensor &bias_aux = tensor_pack->tensors[2];
NVTEBasicTensor bias_aux_data;
bias_aux_data.data_ptr = bias_buf;
bias_aux_data.shape.ndim = 4;
bias_aux_data.shape.data[0] = bias_batch;
bias_aux_data.shape.data[1] = bias_heads;
bias_aux_data.shape.data[2] = q_max_seqlen;
bias_aux_data.shape.data[3] = kv_max_seqlen;
bias_aux_data.dtype = static_cast<NVTEDType>(dtype);
nvte_set_tensor_param(&bias_aux, kNVTERowwiseData, &bias_aux_data);
}
}
nvte_set_tensor_param(&softmax_aux, kNVTERowwiseData, &softmax_aux_data);
}
/*
......@@ -93,9 +107,11 @@ void PrepareFusedAttnBackwardAuxTensors(NVTETensorPack *tensor_pack, const size_
// correct softmax shape for max512 sequence length kernel
if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
Tensor *softmax_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[0]);
softmax_aux->data.shape.at(3) = kv_max_seqlen; // {B,H,Qs,1} -> {B,H,Qs,Ks}
softmax_aux->data.dtype = dtype;
NVTEBasicTensor softmax_aux_data =
nvte_get_tensor_param(tensor_pack->tensors[0], kNVTERowwiseData);
softmax_aux_data.shape.data[3] = kv_max_seqlen; // {B,H,Qs,1} -> {B,H,Qs,Ks}
softmax_aux_data.dtype = static_cast<NVTEDType>(dtype);
nvte_set_tensor_param(&(tensor_pack->tensors[0]), kNVTERowwiseData, &softmax_aux_data);
}
}
......
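
The JAX changes above stop poking Tensor internals and instead round-trip through nvte_get_tensor_param / nvte_set_tensor_param. A hedged sketch of that pattern (the function name and arguments are illustrative):

void shrink_last_dim(NVTETensor tensor, size_t new_last_dim) {
  NVTEBasicTensor p = nvte_get_tensor_param(tensor, kNVTERowwiseData);
  p.shape.data[p.shape.ndim - 1] = new_last_dim;  // e.g. {B,H,Qs,Ks} -> {B,H,Qs,1}
  nvte_set_tensor_param(&tensor, kNVTERowwiseData, &p);
}
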
......@@ -20,6 +20,20 @@ std::vector<size_t> getTensorShape(at::Tensor t) {
return shape;
}
NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape) {
NVTEShape ret;
ret.ndim = torch_shape.size();
constexpr int max_dimensions = sizeof(ret.data) / sizeof(size_t);
NVTE_CHECK(ret.ndim < max_dimensions,
"Torch tensor has too many dimensions. Max supported: ", max_dimensions, " and got ",
ret.ndim, ".");
for (size_t i = 0; i < ret.ndim; ++i) {
const auto& v = torch_shape[i];
ret.data[i] = static_cast<size_t>(v);
}
return ret;
}
std::unique_ptr<Quantizer> convert_quantizer(py::handle quantizer) {
init_extension();
if (quantizer.is_none()) {
......
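
The new convertTorchShape helper avoids building a std::vector<size_t> just to describe a torch shape. A hedged usage sketch (the wrapper function and the Float32 dtype are assumptions, not part of the diff):

NVTEBasicTensor wrap_torch_tensor(const at::Tensor &t) {
  NVTEShape shape = convertTorchShape(t.sizes());
  return {t.data_ptr(), static_cast<NVTEDType>(DType::kFloat32), shape};
}
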
......@@ -351,6 +351,7 @@ std::vector<size_t> convertShape(const NVTEShape& shape);
int roundup(const int value, const int multiple);
NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape);
} // namespace transformer_engine::pytorch
namespace std {
......
......@@ -9,8 +9,8 @@
#include <string>
#include "common/common.h"
#include "extensions.h"
#include "transformer_engine/transformer_engine.h"
namespace transformer_engine::pytorch {
......@@ -34,30 +34,35 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor& amax_reductio
const std::string& amax_compute_algo,
DType fp8_dtype, float margin) {
size_t num_tensors = amax_histories.size();
std::vector<Tensor> t_amax_histories(num_tensors);
std::vector<Tensor> t_scales(num_tensors);
std::vector<NVTETensor> te_amax_histories(num_tensors);
std::vector<NVTETensor> te_scales(num_tensors);
std::vector<NVTETensor> te_amax_histories;
std::vector<NVTETensor> te_scales;
te_amax_histories.reserve(num_tensors);
te_scales.reserve(num_tensors);
for (size_t i = 0; i < num_tensors; i++) {
t_amax_histories[i].data.dptr = amax_histories[i].data_ptr();
auto amax_sizes = amax_histories[i].sizes().vec();
std::vector<size_t> amax_shape{amax_sizes.begin(), amax_sizes.end()};
t_amax_histories[i].data.shape = amax_shape;
t_amax_histories[i].data.dtype = DType::kFloat32;
t_scales[i].data.dptr = scales[i].data_ptr();
auto scale_sizes = scales[i].sizes().vec();
std::vector<size_t> scale_shape{scale_sizes.begin(), scale_sizes.end()};
t_scales[i].data.shape = scale_shape;
t_scales[i].data.dtype = DType::kFloat32;
te_amax_histories[i] = reinterpret_cast<NVTETensor>(&t_amax_histories[i]);
te_scales[i] = reinterpret_cast<NVTETensor>(&t_scales[i]);
te_amax_histories.push_back(nvte_create_tensor(NVTE_DELAYED_TENSOR_SCALING));
NVTETensor& amax_history = te_amax_histories.back();
NVTEShape amax_shape = convertTorchShape(amax_histories[i].sizes());
NVTEBasicTensor amax_history_data = {amax_histories[i].data_ptr(),
static_cast<NVTEDType>(DType::kFloat32), amax_shape};
nvte_set_tensor_param(&amax_history, kNVTERowwiseData, &amax_history_data);
te_scales.push_back(nvte_create_tensor(NVTE_DELAYED_TENSOR_SCALING));
NVTETensor& scale = te_scales.back();
NVTEShape scale_shape = convertTorchShape(scales[i].sizes());
NVTEBasicTensor scale_data = {scales[i].data_ptr(), static_cast<NVTEDType>(DType::kFloat32),
scale_shape};
nvte_set_tensor_param(&scale, kNVTERowwiseData, &scale_data);
}
nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
makeTransformerEngineTensor(amax_reduction_buffer).data(), te_amax_histories, te_scales,
amax_compute_algo.c_str(), static_cast<NVTEDType>(fp8_dtype), margin,
at::cuda::getCurrentCUDAStream());
for (auto& t : te_amax_histories) {
nvte_destroy_tensor(t);
}
for (auto& t : te_scales) {
nvte_destroy_tensor(t);
}
}
} // namespace transformer_engine::pytorch
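
Because every nvte_create_tensor in the loop above must now be paired with nvte_destroy_tensor, an early return or exception would leak pool slots. Not part of this PR, but a hedged sketch of an RAII guard that would make the cleanup automatic:

#include <transformer_engine/transformer_engine.h>

class ScopedNVTETensor {
 public:
  explicit ScopedNVTETensor(NVTEScalingMode mode) : t_(nvte_create_tensor(mode)) {}
  ~ScopedNVTETensor() { nvte_destroy_tensor(t_); }
  ScopedNVTETensor(const ScopedNVTETensor &) = delete;
  ScopedNVTETensor &operator=(const ScopedNVTETensor &) = delete;
  NVTETensor get() const { return t_; }

 private:
  NVTETensor t_;
};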