Unverified Commit 4292653c authored by Przemyslaw Tredak's avatar Przemyslaw Tredak Committed by GitHub
Browse files

Avoid memory allocations and deallocations when creating NVTETensor (#1813)



* Changed the Tensor allocation strategy
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fixes
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Disable debug flag
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix the double free error
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fixed pyTorch recipe extension
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Hide TensorAllocator and fix the usage in LayerNorm
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Cleaning
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix permutation
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

---------
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 41909dc8
...@@ -90,11 +90,12 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens ...@@ -90,11 +90,12 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
// Compute FP8 transpose if required // Compute FP8 transpose if required
if (z->has_columnwise_data() && is_tensor_scaling(z->scaling_mode)) { if (z->has_columnwise_data() && is_tensor_scaling(z->scaling_mode)) {
Tensor transpose_data; NVTETensor transpose_data = nvte_create_tensor(z->scaling_mode);
transpose_data.data = z->columnwise_data; auto *t = convertNVTETensor(transpose_data);
transpose_data.scaling_mode = z->scaling_mode; t->data = z->columnwise_data;
nvte_transpose(reinterpret_cast<NVTETensor>(z), reinterpret_cast<NVTETensor>(&transpose_data),
stream); nvte_transpose(static_cast<NVTETensor>(*z), transpose_data, stream);
nvte_destroy_tensor(transpose_data);
} }
return; return;
...@@ -171,10 +172,9 @@ void nvte_rmsnorm_fwd(const NVTETensor x, // Nxhidden_size ...@@ -171,10 +172,9 @@ void nvte_rmsnorm_fwd(const NVTETensor x, // Nxhidden_size
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_rmsnorm_fwd); NVTE_API_CALL(nvte_rmsnorm_fwd);
using namespace transformer_engine; using namespace transformer_engine;
rmsnorm_fwd(*reinterpret_cast<const Tensor *>(x), *reinterpret_cast<const Tensor *>(gamma), rmsnorm_fwd(*convertNVTETensorCheck(x), *convertNVTETensorCheck(gamma), epsilon,
epsilon, reinterpret_cast<Tensor *>(z), reinterpret_cast<Tensor *>(rsigma), convertNVTETensor(z), convertNVTETensor(rsigma), convertNVTETensor(workspace),
reinterpret_cast<Tensor *>(workspace), multiprocessorCount, zero_centered_gamma, multiprocessorCount, zero_centered_gamma, stream);
stream);
} }
void nvte_rmsnorm_bwd(const NVTETensor dz, // Nxhidden_size void nvte_rmsnorm_bwd(const NVTETensor dz, // Nxhidden_size
...@@ -186,9 +186,8 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, // Nxhidden_size ...@@ -186,9 +186,8 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, // Nxhidden_size
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_rmsnorm_bwd); NVTE_API_CALL(nvte_rmsnorm_bwd);
using namespace transformer_engine; using namespace transformer_engine;
rmsnorm_bwd(*reinterpret_cast<const Tensor *>(dz), *reinterpret_cast<const Tensor *>(x), rmsnorm_bwd(*convertNVTETensorCheck(dz), *convertNVTETensorCheck(x),
*reinterpret_cast<const Tensor *>(rsigma), *reinterpret_cast<const Tensor *>(gamma), *convertNVTETensorCheck(rsigma), *convertNVTETensorCheck(gamma),
reinterpret_cast<Tensor *>(dx), reinterpret_cast<Tensor *>(dgamma), convertNVTETensor(dx), convertNVTETensor(dgamma), convertNVTETensor(workspace),
reinterpret_cast<Tensor *>(workspace), multiprocessorCount, zero_centered_gamma, multiprocessorCount, zero_centered_gamma, stream);
stream);
} }
...@@ -318,22 +318,16 @@ void nvte_permute(const NVTETensor input, NVTETensor output, const NVTETensor so ...@@ -318,22 +318,16 @@ void nvte_permute(const NVTETensor input, NVTETensor output, const NVTETensor so
NVTETensor row_id_map, const NVTETensor prob, NVTETensor prob_grad, NVTETensor row_id_map, const NVTETensor prob, NVTETensor prob_grad,
const NVTETensor input_fwd, const int num_rows, const int topK, const NVTETensor input_fwd, const int num_rows, const int topK,
const int num_cols, const int num_out_tokens, cudaStream_t stream) { const int num_cols, const int num_out_tokens, cudaStream_t stream) {
using namespace transformer_engine;
NVTE_API_CALL(nvte_permute); NVTE_API_CALL(nvte_permute);
const transformer_engine::Tensor *input_cu = const Tensor *input_cu = convertNVTETensorCheck(input);
reinterpret_cast<const transformer_engine::Tensor *>(input); const Tensor *output_cu = convertNVTETensorCheck(output);
const transformer_engine::Tensor *output_cu = const Tensor *sorted_row_id_cu = convertNVTETensorCheck(sorted_row_id);
reinterpret_cast<const transformer_engine::Tensor *>(output); const Tensor *row_id_map_cu = convertNVTETensorCheck(row_id_map);
const transformer_engine::Tensor *sorted_row_id_cu = const Tensor *prob_cu = convertNVTETensorCheck(prob);
reinterpret_cast<const transformer_engine::Tensor *>(sorted_row_id); const Tensor *prob_grad_cu = convertNVTETensorCheck(prob_grad);
const transformer_engine::Tensor *row_id_map_cu = const Tensor *input_fwd_cu = convertNVTETensorCheck(input_fwd);
reinterpret_cast<const transformer_engine::Tensor *>(row_id_map);
const transformer_engine::Tensor *prob_cu =
reinterpret_cast<const transformer_engine::Tensor *>(prob);
const transformer_engine::Tensor *prob_grad_cu =
reinterpret_cast<const transformer_engine::Tensor *>(prob_grad);
const transformer_engine::Tensor *input_fwd_cu =
reinterpret_cast<const transformer_engine::Tensor *>(input_fwd);
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
input_cu->data.dtype, T, input_cu->data.dtype, T,
...@@ -350,16 +344,13 @@ void nvte_permute(const NVTETensor input, NVTETensor output, const NVTETensor so ...@@ -350,16 +344,13 @@ void nvte_permute(const NVTETensor input, NVTETensor output, const NVTETensor so
void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id_map, void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id_map,
const NVTETensor prob, const int num_rows, const int topK, const int num_cols, const NVTETensor prob, const int num_rows, const int topK, const int num_cols,
cudaStream_t stream) { cudaStream_t stream) {
using namespace transformer_engine;
NVTE_API_CALL(nvte_unpermute); NVTE_API_CALL(nvte_unpermute);
const transformer_engine::Tensor *input_cu = const Tensor *input_cu = convertNVTETensorCheck(input);
reinterpret_cast<const transformer_engine::Tensor *>(input); const Tensor *output_cu = convertNVTETensorCheck(output);
const transformer_engine::Tensor *output_cu = const Tensor *row_id_map_cu = convertNVTETensorCheck(row_id_map);
reinterpret_cast<const transformer_engine::Tensor *>(output); const Tensor *prob_cu = convertNVTETensorCheck(prob);
const transformer_engine::Tensor *row_id_map_cu =
reinterpret_cast<const transformer_engine::Tensor *>(row_id_map);
const transformer_engine::Tensor *prob_cu =
reinterpret_cast<const transformer_engine::Tensor *>(prob);
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
input_cu->data.dtype, T, input_cu->data.dtype, T,
......
...@@ -108,7 +108,7 @@ void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaSt ...@@ -108,7 +108,7 @@ void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaSt
// Check input tensor // Check input tensor
NVTE_CHECK(input_ != nullptr, "Invalid input tensor (got NULL)"); NVTE_CHECK(input_ != nullptr, "Invalid input tensor (got NULL)");
const auto &input = *reinterpret_cast<const Tensor *>(input_); const auto &input = *convertNVTETensorCheck(input_);
NVTE_CHECK(input.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, NVTE_CHECK(input.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Input tensor for amax computation must unquantized, " "Input tensor for amax computation must unquantized, "
"but got scaling_mode=", "but got scaling_mode=",
...@@ -121,7 +121,7 @@ void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaSt ...@@ -121,7 +121,7 @@ void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaSt
// Check output tensor // Check output tensor
NVTE_CHECK(output_ != nullptr, "Invalid output tensor (got NULL)"); NVTE_CHECK(output_ != nullptr, "Invalid output tensor (got NULL)");
auto &output = *reinterpret_cast<Tensor *>(output_); auto &output = *convertNVTETensorCheck(output_);
NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Output tensor for amax computation must be FP8 tensor with per-tensor scaling, " "Output tensor for amax computation must be FP8 tensor with per-tensor scaling, "
"but got scaling_mode=", "but got scaling_mode=",
...@@ -166,7 +166,7 @@ void nvte_compute_scale_from_amax(NVTETensor output_, const NVTEQuantizationConf ...@@ -166,7 +166,7 @@ void nvte_compute_scale_from_amax(NVTETensor output_, const NVTEQuantizationConf
// Check output tensor // Check output tensor
NVTE_CHECK(output_ != nullptr, "Invalid output tensor (got NULL)"); NVTE_CHECK(output_ != nullptr, "Invalid output tensor (got NULL)");
auto &output = *reinterpret_cast<Tensor *>(output_); auto &output = *convertNVTETensorCheck(output_);
NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Tensor must be FP8 tensor with per-tensor scaling, " "Tensor must be FP8 tensor with per-tensor scaling, "
"but got scaling_mode=", "but got scaling_mode=",
......
...@@ -397,9 +397,9 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update( ...@@ -397,9 +397,9 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update(
NVTE_API_CALL(nvte_delayed_scaling_recipe_amax_and_scale_update); NVTE_API_CALL(nvte_delayed_scaling_recipe_amax_and_scale_update);
using namespace transformer_engine; using namespace transformer_engine;
delayed_scaling_recipe::amax_and_scale_update( delayed_scaling_recipe::amax_and_scale_update(
*reinterpret_cast<const Tensor*>(amax_history), *reinterpret_cast<const Tensor*>(scale), *convertNVTETensorCheck(amax_history), *convertNVTETensorCheck(scale),
reinterpret_cast<Tensor*>(updated_amax_history), reinterpret_cast<Tensor*>(updated_scale), convertNVTETensor(updated_amax_history), convertNVTETensor(updated_scale), amax_compute_algo,
amax_compute_algo, static_cast<DType>(fp8_dtype), margin, stream); static_cast<DType>(fp8_dtype), margin, stream);
} }
void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
...@@ -411,10 +411,10 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( ...@@ -411,10 +411,10 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
size_t num_tensors = amax_histories.size(); size_t num_tensors = amax_histories.size();
std::vector<Tensor*> t_amax_histories, t_scales; std::vector<Tensor*> t_amax_histories, t_scales;
for (size_t i = 0; i < num_tensors; i++) { for (size_t i = 0; i < num_tensors; i++) {
t_amax_histories.push_back(reinterpret_cast<Tensor*>(amax_histories[i])); t_amax_histories.push_back(convertNVTETensor(amax_histories[i]));
t_scales.push_back(reinterpret_cast<Tensor*>(scales[i])); t_scales.push_back(convertNVTETensor(scales[i]));
} }
delayed_scaling_recipe::amax_and_scale_update_after_reduction( delayed_scaling_recipe::amax_and_scale_update_after_reduction(
*reinterpret_cast<const Tensor*>(amax_reduction_buffer), t_amax_histories, t_scales, *convertNVTETensorCheck(amax_reduction_buffer), t_amax_histories, t_scales, amax_compute_algo,
amax_compute_algo, static_cast<DType>(fp8_dtype), margin, stream); static_cast<DType>(fp8_dtype), margin, stream);
} }
...@@ -227,8 +227,8 @@ void nvte_fp8_block_scaling_compute_partial_amax(const NVTETensor inp, NVTETenso ...@@ -227,8 +227,8 @@ void nvte_fp8_block_scaling_compute_partial_amax(const NVTETensor inp, NVTETenso
NVTE_API_CALL(nvte_fp8_block_scaling_compute_partial_amax); NVTE_API_CALL(nvte_fp8_block_scaling_compute_partial_amax);
using namespace transformer_engine; using namespace transformer_engine;
fp8_block_scaling_recipe::fp8_block_scaling_compute_partial_amax( fp8_block_scaling_recipe::fp8_block_scaling_compute_partial_amax(
*reinterpret_cast<const Tensor *>(inp), *reinterpret_cast<Tensor *>(amax), h, w, *convertNVTETensorCheck(inp), *convertNVTETensorCheck(amax), h, w, amax_stride_h,
amax_stride_h, amax_stride_w, start_offset, block_len, stream); amax_stride_w, start_offset, block_len, stream);
} }
void nvte_fp8_block_scaling_partial_cast(const NVTETensor inp, NVTETensor out, void nvte_fp8_block_scaling_partial_cast(const NVTETensor inp, NVTETensor out,
...@@ -239,7 +239,7 @@ void nvte_fp8_block_scaling_partial_cast(const NVTETensor inp, NVTETensor out, ...@@ -239,7 +239,7 @@ void nvte_fp8_block_scaling_partial_cast(const NVTETensor inp, NVTETensor out,
NVTE_API_CALL(nvte_fp8_block_scaling_partial_cast); NVTE_API_CALL(nvte_fp8_block_scaling_partial_cast);
using namespace transformer_engine; using namespace transformer_engine;
fp8_block_scaling_recipe::fp8_block_scaling_partial_cast( fp8_block_scaling_recipe::fp8_block_scaling_partial_cast(
*reinterpret_cast<const Tensor *>(inp), *reinterpret_cast<Tensor *>(out), *convertNVTETensorCheck(inp), *convertNVTETensorCheck(out), *convertNVTETensorCheck(scale), h,
*reinterpret_cast<const Tensor *>(scale), h, w, scale_stride_h, scale_stride_w, start_offset, w, scale_stride_h, scale_stride_w, start_offset, block_len, static_cast<DType>(out_dtype),
block_len, static_cast<DType>(out_dtype), stream); stream);
} }
...@@ -333,6 +333,5 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s ...@@ -333,6 +333,5 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream) { void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_swizzle_scaling_factors); NVTE_API_CALL(nvte_swizzle_scaling_factors);
using namespace transformer_engine; using namespace transformer_engine;
swizzle_scaling_factors(reinterpret_cast<const Tensor*>(input), reinterpret_cast<Tensor*>(output), swizzle_scaling_factors(convertNVTETensorCheck(input), convertNVTETensorCheck(output), stream);
stream);
} }
...@@ -6,11 +6,15 @@ ...@@ -6,11 +6,15 @@
#include <transformer_engine/transformer_engine.h> #include <transformer_engine/transformer_engine.h>
#include <atomic>
#include <climits>
#include <cstring> #include <cstring>
#include <iostream> #include <iostream>
#include <mutex>
#include "common.h" #include "common.h"
#include "common/util/cuda_runtime.h" #include "common/util/cuda_runtime.h"
#include "common/util/logging.h"
namespace transformer_engine { namespace transformer_engine {
...@@ -192,24 +196,139 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt ...@@ -192,24 +196,139 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
CheckScaleTensorShape(t, name); CheckScaleTensorShape(t, name);
} }
class TensorAllocator {
public:
static TensorAllocator &instance() {
static TensorAllocator allocator;
return allocator;
}
~TensorAllocator() {}
NVTETensor Allocate(NVTEScalingMode mode) {
std::lock_guard<std::mutex> lock(mutex);
if (!free_list.empty()) {
uintptr_t index = free_list.back();
NVTETensor ret = reinterpret_cast<NVTETensor>(index);
free_list.pop_back();
if (debug) {
std::cout << "Allocated " << index
<< " from free list. Free list size: " << free_list.size() << " and capacity "
<< free_list.capacity() << std::endl;
}
// 1-based indexing
memory[index - 1].scaling_mode = mode;
return ret;
}
if (memory.size() < memory.capacity()) {
memory.emplace_back();
Tensor &t = memory.back();
size = memory.size();
// 1-based indexing
uintptr_t index = memory.size();
if (debug) {
std::cout << "Allocated " << index << ". Memory size: " << memory.size() << " and capacity "
<< memory.capacity() << std::endl;
}
t.scaling_mode = mode;
t.nvte_tensor = reinterpret_cast<NVTETensor>(index);
return reinterpret_cast<NVTETensor>(index);
}
NVTE_ERROR("Cannot allocate a new NVTETensor. Maximum number of tensors reached: ",
MAX_TENSOR_NUM, ". There is probably a memory leak in your application.");
}
void Free(NVTETensor t) {
std::lock_guard<std::mutex> lock(mutex);
uintptr_t index = reinterpret_cast<uintptr_t>(t);
if (index == 0) return;
NVTE_CHECK(index <= memory.size(), "Invalid tensor.");
free_list.push_back(index);
// Clean up
memory[index - 1].clear();
if (debug) {
std::cout << "Freed " << index << ". Free list size: " << free_list.size() << " and capacity "
<< free_list.capacity() << std::endl;
}
}
void Free(NVTETensor *t, size_t N) {
std::lock_guard<std::mutex> lock(mutex);
for (size_t i = 0; i < N; ++i) {
uintptr_t index = reinterpret_cast<uintptr_t>(t[i]);
if (index == 0) continue;
NVTE_CHECK(index <= memory.size(), "Invalid tensor.");
free_list.push_back(index);
// Clean up
memory[index - 1].clear();
}
if (debug) {
std::cout << "Freed range of" << N << " tensors. Free list size: " << free_list.size()
<< " and capacity " << free_list.capacity() << std::endl;
}
}
Tensor *convertNVTETensor(NVTETensor t) {
uintptr_t index = reinterpret_cast<uintptr_t>(t);
// 1-based indexing to enable 0-initialization of NVTETensor
// to be invalid tensor
static_assert(nullptr == 0);
if (index != 0 && index <= size) {
return &(memory[index - 1]);
}
return nullptr;
}
void setDebug(bool debug) {
std::lock_guard<std::mutex> lock(mutex);
this->debug = debug;
}
private:
TensorAllocator() {
std::lock_guard<std::mutex> lock(mutex);
memory.reserve(MAX_TENSOR_NUM);
}
std::mutex mutex;
std::atomic<size_t> size;
// Allocate at most 20 MB for tensors
// Should be replaced by virtual memory allocation
const size_t MAX_TENSOR_NUM = 20 * 1024 * 1024 / sizeof(Tensor);
std::vector<uintptr_t> free_list;
std::vector<Tensor> memory;
bool debug = false;
};
Tensor *convertNVTETensor(const NVTETensor t) {
return TensorAllocator::instance().convertNVTETensor(t);
}
Tensor *convertNVTETensorCheck(const NVTETensor t) {
Tensor *ptr = TensorAllocator::instance().convertNVTETensor(t);
NVTE_CHECK(ptr != nullptr, "Invalid tensor.");
return ptr;
}
} // namespace transformer_engine } // namespace transformer_engine
NVTETensor nvte_create_tensor(NVTEScalingMode scaling_mode) { NVTETensor nvte_create_tensor(NVTEScalingMode scaling_mode) {
transformer_engine::Tensor *ret = new transformer_engine::Tensor; NVTETensor ret = transformer_engine::TensorAllocator::instance().Allocate(scaling_mode);
ret->scaling_mode = scaling_mode;
return ret; return ret;
} }
void nvte_destroy_tensor(NVTETensor tensor) { void nvte_destroy_tensor(NVTETensor tensor) {
if (tensor == nullptr) return; transformer_engine::TensorAllocator::instance().Free(tensor);
auto *t = reinterpret_cast<transformer_engine::Tensor *>(tensor); }
delete t;
void nvte_destroy_tensors(NVTETensor *tensors, size_t N) {
transformer_engine::TensorAllocator::instance().Free(tensors, N);
} }
NVTEDType nvte_tensor_type(const NVTETensor tensor) { NVTEDType nvte_tensor_type(const NVTETensor tensor) {
if (tensor == nullptr) return kNVTEFloat32; auto *t = transformer_engine::convertNVTETensor(tensor);
return static_cast<NVTEDType>( if (t == nullptr) return kNVTEFloat32;
reinterpret_cast<const transformer_engine::Tensor *>(tensor)->dtype()); return static_cast<NVTEDType>(t->dtype());
} }
NVTEShape nvte_make_shape(const size_t *data, size_t ndim) { NVTEShape nvte_make_shape(const size_t *data, size_t ndim) {
...@@ -227,23 +346,24 @@ NVTEShape nvte_make_shape(const size_t *data, size_t ndim) { ...@@ -227,23 +346,24 @@ NVTEShape nvte_make_shape(const size_t *data, size_t ndim) {
} }
NVTEShape nvte_tensor_shape(const NVTETensor tensor) { NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
if (tensor == nullptr) { auto *t = transformer_engine::convertNVTETensor(tensor);
if (t == nullptr) {
NVTE_ERROR("Invalid tensor"); NVTE_ERROR("Invalid tensor");
} }
// Determine tensor shape depending on tensor format // Determine tensor shape depending on tensor format
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor); const std::vector<size_t> &shape = t->shape();
std::vector<size_t> shape = t.shape();
return nvte_make_shape(shape.data(), shape.size()); return nvte_make_shape(shape.data(), shape.size());
} }
NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) { NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) {
if (tensor == nullptr) { auto *t = transformer_engine::convertNVTETensor(tensor);
if (t == nullptr) {
NVTE_ERROR("Invalid tensor"); NVTE_ERROR("Invalid tensor");
} }
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor); const std::vector<size_t> &shape = t->columnwise_data.shape;
return nvte_make_shape(t.columnwise_data.shape.data(), t.columnwise_data.shape.size()); return nvte_make_shape(shape.data(), shape.size());
} }
size_t nvte_tensor_ndims(const NVTETensor tensor) { return nvte_tensor_shape(tensor).ndim; } size_t nvte_tensor_ndims(const NVTETensor tensor) { return nvte_tensor_shape(tensor).ndim; }
...@@ -265,82 +385,82 @@ size_t nvte_tensor_numel(const NVTETensor tensor) { ...@@ -265,82 +385,82 @@ size_t nvte_tensor_numel(const NVTETensor tensor) {
} }
size_t nvte_tensor_element_size(const NVTETensor tensor) { size_t nvte_tensor_element_size(const NVTETensor tensor) {
if (tensor == nullptr) return sizeof(float); auto *t = transformer_engine::convertNVTETensor(tensor);
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor); if (t == nullptr) return sizeof(float);
return transformer_engine::typeToSize(t.dtype()); return transformer_engine::typeToSize(t->dtype());
} }
void *nvte_tensor_data(const NVTETensor tensor) { void *nvte_tensor_data(const NVTETensor tensor) {
if (tensor == nullptr) return nullptr; auto *t = transformer_engine::convertNVTETensor(tensor);
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor); if (t == nullptr) return nullptr;
return t.data.dptr; return t->data.dptr;
} }
void *nvte_tensor_columnwise_data(const NVTETensor tensor) { void *nvte_tensor_columnwise_data(const NVTETensor tensor) {
if (tensor == nullptr) return nullptr; auto *t = transformer_engine::convertNVTETensor(tensor);
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor); if (t == nullptr) return nullptr;
return t.columnwise_data.dptr; return t->columnwise_data.dptr;
} }
float *nvte_tensor_amax(const NVTETensor tensor) { float *nvte_tensor_amax(const NVTETensor tensor) {
if (tensor == nullptr) return nullptr; auto *t = transformer_engine::convertNVTETensor(tensor);
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor); if (t == nullptr) return nullptr;
NVTE_CHECK(t.amax.dtype == transformer_engine::DType::kFloat32, NVTE_CHECK(t->amax.dtype == transformer_engine::DType::kFloat32,
"Tensor's amax must have Float32 type!"); "Tensor's amax must have Float32 type!");
return reinterpret_cast<float *>(t.amax.dptr); return reinterpret_cast<float *>(t->amax.dptr);
} }
float *nvte_tensor_scale(const NVTETensor tensor) { float *nvte_tensor_scale(const NVTETensor tensor) {
if (tensor == nullptr) return nullptr; auto *t = transformer_engine::convertNVTETensor(tensor);
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor); if (t == nullptr) return nullptr;
NVTE_CHECK(t.scale.dtype == transformer_engine::DType::kFloat32, NVTE_CHECK(t->scale.dtype == transformer_engine::DType::kFloat32,
"Tensor's scale must have Float32 type!"); "Tensor's scale must have Float32 type!");
return reinterpret_cast<float *>(t.scale.dptr); return reinterpret_cast<float *>(t->scale.dptr);
} }
float *nvte_tensor_scale_inv(const NVTETensor tensor) { float *nvte_tensor_scale_inv(const NVTETensor tensor) {
if (tensor == nullptr) return nullptr; auto *t = transformer_engine::convertNVTETensor(tensor);
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor); if (t == nullptr) return nullptr;
return reinterpret_cast<float *>(t.scale_inv.dptr); return reinterpret_cast<float *>(t->scale_inv.dptr);
} }
void *nvte_tensor_columnwise_scale_inv(const NVTETensor tensor) { void *nvte_tensor_columnwise_scale_inv(const NVTETensor tensor) {
if (tensor == nullptr) return nullptr; auto *t = transformer_engine::convertNVTETensor(tensor);
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor); if (t == nullptr) return nullptr;
return t.columnwise_scale_inv.dptr; return t->columnwise_scale_inv.dptr;
} }
NVTEShape nvte_tensor_scale_inv_shape(const NVTETensor tensor) { NVTEShape nvte_tensor_scale_inv_shape(const NVTETensor tensor) {
if (tensor == nullptr) { auto *t = transformer_engine::convertNVTETensor(tensor);
if (t == nullptr) {
return nvte_make_shape(nullptr, 0); return nvte_make_shape(nullptr, 0);
} }
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor); return nvte_make_shape(t->scale_inv.shape.data(), t->scale_inv.shape.size());
return nvte_make_shape(t.scale_inv.shape.data(), t.scale_inv.shape.size());
} }
void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name, void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name,
const NVTEBasicTensor *param) { const NVTEBasicTensor *param) {
NVTE_CHECK(tensor != nullptr, "Tensor pointer can't be NULL."); NVTE_CHECK(tensor != nullptr, "Tensor pointer can't be NULL.");
NVTE_CHECK(*tensor != nullptr, "Tensor is not allocated."); auto *t = transformer_engine::convertNVTETensor(*tensor);
auto &t = *reinterpret_cast<transformer_engine::Tensor *>(*tensor); NVTE_CHECK(t != nullptr, "Tensor is not allocated.");
switch (param_name) { switch (param_name) {
case kNVTERowwiseData: case kNVTERowwiseData:
t.data = *param; t->data = *param;
break; break;
case kNVTEColumnwiseData: case kNVTEColumnwiseData:
t.columnwise_data = *param; t->columnwise_data = *param;
break; break;
case kNVTEScale: case kNVTEScale:
t.scale = *param; t->scale = *param;
break; break;
case kNVTEAmax: case kNVTEAmax:
t.amax = *param; t->amax = *param;
break; break;
case kNVTERowwiseScaleInv: case kNVTERowwiseScaleInv:
t.scale_inv = *param; t->scale_inv = *param;
break; break;
case kNVTEColumnwiseScaleInv: case kNVTEColumnwiseScaleInv:
t.columnwise_scale_inv = *param; t->columnwise_scale_inv = *param;
break; break;
default: default:
NVTE_ERROR("Unknown tensor parameter!"); NVTE_ERROR("Unknown tensor parameter!");
...@@ -351,7 +471,7 @@ NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam p ...@@ -351,7 +471,7 @@ NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam p
if (tensor == nullptr) { if (tensor == nullptr) {
return {nullptr, kNVTEFloat32, nvte_make_shape(nullptr, 0)}; return {nullptr, kNVTEFloat32, nvte_make_shape(nullptr, 0)};
} }
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor); const auto &t = *transformer_engine::convertNVTETensorCheck(tensor);
switch (param_name) { switch (param_name) {
case kNVTERowwiseData: case kNVTERowwiseData:
return t.data; return t.data;
...@@ -371,25 +491,27 @@ NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam p ...@@ -371,25 +491,27 @@ NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam p
} }
NVTEScalingMode nvte_tensor_scaling_mode(const NVTETensor tensor) { NVTEScalingMode nvte_tensor_scaling_mode(const NVTETensor tensor) {
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor); if (tensor == nullptr) {
return NVTE_DELAYED_TENSOR_SCALING;
}
const auto &t = *transformer_engine::convertNVTETensorCheck(tensor);
return t.scaling_mode; return t.scaling_mode;
} }
void nvte_tensor_pack_create(NVTETensorPack *pack) { void nvte_tensor_pack_create(NVTETensorPack *pack) {
for (int i = 0; i < pack->MAX_SIZE; i++) { for (int i = 0; i < pack->MAX_SIZE; i++) {
pack->tensors[i] = reinterpret_cast<NVTETensor>(new transformer_engine::Tensor); pack->tensors[i] =
transformer_engine::TensorAllocator::instance().Allocate(NVTE_DELAYED_TENSOR_SCALING);
} }
} }
void nvte_tensor_pack_destroy(NVTETensorPack *pack) { void nvte_tensor_pack_destroy(NVTETensorPack *pack) {
for (int i = 0; i < pack->MAX_SIZE; i++) { transformer_engine::TensorAllocator::instance().Free(pack->tensors, pack->MAX_SIZE);
auto *t = reinterpret_cast<transformer_engine::Tensor *>(pack->tensors[i]);
delete t;
}
} }
void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream) { void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream) {
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor); if (tensor == nullptr) return;
const auto &t = *transformer_engine::convertNVTETensorCheck(tensor);
// Zero out tensor data if allocated // Zero out tensor data if allocated
if (t.data.dptr != nullptr) { if (t.data.dptr != nullptr) {
size_t size_in_bytes = nvte_tensor_element_size(tensor) * nvte_tensor_numel(tensor); size_t size_in_bytes = nvte_tensor_element_size(tensor) * nvte_tensor_numel(tensor);
......
...@@ -348,15 +348,15 @@ void nvte_cast_transpose(const NVTETensor input, NVTETensor output, cudaStream_t ...@@ -348,15 +348,15 @@ void nvte_cast_transpose(const NVTETensor input, NVTETensor output, cudaStream_t
NVTE_API_CALL(nvte_cast_transpose); NVTE_API_CALL(nvte_cast_transpose);
using namespace transformer_engine; using namespace transformer_engine;
auto noop = Tensor(); auto noop = Tensor();
transformer_engine::detail::cast_transpose(*reinterpret_cast<const Tensor *>(input), noop, transformer_engine::detail::cast_transpose(*convertNVTETensorCheck(input), noop,
reinterpret_cast<Tensor *>(output), stream); convertNVTETensor(output), stream);
} }
void nvte_cast_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output, void nvte_cast_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose_with_noop); NVTE_API_CALL(nvte_cast_transpose_with_noop);
using namespace transformer_engine; using namespace transformer_engine;
transformer_engine::detail::cast_transpose(*reinterpret_cast<const Tensor *>(input), transformer_engine::detail::cast_transpose(*convertNVTETensorCheck(input),
*reinterpret_cast<const Tensor *>(noop), *convertNVTETensorCheck(noop),
reinterpret_cast<Tensor *>(output), stream); convertNVTETensor(output), stream);
} }
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "../util/string.h" #include "../util/string.h"
#include "../utils.cuh" #include "../utils.cuh"
#include "cast_transpose.h" #include "cast_transpose.h"
#include "common/common.h"
namespace transformer_engine { namespace transformer_engine {
...@@ -1267,9 +1268,8 @@ void nvte_cast_transpose_dbias(const NVTETensor input, NVTETensor output, NVTETe ...@@ -1267,9 +1268,8 @@ void nvte_cast_transpose_dbias(const NVTETensor input, NVTETensor output, NVTETe
constexpr const NVTETensor activation_input = nullptr; constexpr const NVTETensor activation_input = nullptr;
cast_transpose_fused<IS_DBIAS, IS_DACT, IS_ACT, ComputeType, Empty, nullptr>( cast_transpose_fused<IS_DBIAS, IS_DACT, IS_ACT, ComputeType, Empty, nullptr>(
*reinterpret_cast<const Tensor *>(input), reinterpret_cast<const Tensor *>(activation_input), *convertNVTETensorCheck(input), convertNVTETensor(activation_input),
reinterpret_cast<Tensor *>(output), reinterpret_cast<Tensor *>(dbias), convertNVTETensor(output), convertNVTETensor(dbias), convertNVTETensor(workspace), stream);
reinterpret_cast<Tensor *>(workspace), stream);
} }
void nvte_cast_transpose_dbias_dgelu(const NVTETensor input, const NVTETensor act_input, void nvte_cast_transpose_dbias_dgelu(const NVTETensor input, const NVTETensor act_input,
...@@ -1284,9 +1284,9 @@ void nvte_cast_transpose_dbias_dgelu(const NVTETensor input, const NVTETensor ac ...@@ -1284,9 +1284,9 @@ void nvte_cast_transpose_dbias_dgelu(const NVTETensor input, const NVTETensor ac
constexpr bool IS_ACT = false; constexpr bool IS_ACT = false;
cast_transpose_fused<IS_DBIAS, IS_DACT, IS_ACT, ComputeType, Empty, dgelu<fp32, fp32>>( cast_transpose_fused<IS_DBIAS, IS_DACT, IS_ACT, ComputeType, Empty, dgelu<fp32, fp32>>(
*reinterpret_cast<const Tensor *>(input), reinterpret_cast<const Tensor *>(act_input), *convertNVTETensorCheck(input), convertNVTETensorCheck(act_input),
reinterpret_cast<Tensor *>(output), reinterpret_cast<Tensor *>(dbias), convertNVTETensorCheck(output), convertNVTETensorCheck(dbias), convertNVTETensor(workspace),
reinterpret_cast<Tensor *>(workspace), stream); stream);
} }
void nvte_cast_transpose_dbias_dsilu(const NVTETensor input, const NVTETensor silu_input, void nvte_cast_transpose_dbias_dsilu(const NVTETensor input, const NVTETensor silu_input,
...@@ -1301,9 +1301,9 @@ void nvte_cast_transpose_dbias_dsilu(const NVTETensor input, const NVTETensor si ...@@ -1301,9 +1301,9 @@ void nvte_cast_transpose_dbias_dsilu(const NVTETensor input, const NVTETensor si
constexpr bool IS_ACT = false; constexpr bool IS_ACT = false;
cast_transpose_fused<IS_DBIAS, IS_DACT, IS_ACT, ComputeType, Empty, dsilu<fp32, fp32>>( cast_transpose_fused<IS_DBIAS, IS_DACT, IS_ACT, ComputeType, Empty, dsilu<fp32, fp32>>(
*reinterpret_cast<const Tensor *>(input), reinterpret_cast<const Tensor *>(silu_input), *convertNVTETensorCheck(input), convertNVTETensorCheck(silu_input),
reinterpret_cast<Tensor *>(output), reinterpret_cast<Tensor *>(dbias), convertNVTETensorCheck(output), convertNVTETensorCheck(dbias), convertNVTETensor(workspace),
reinterpret_cast<Tensor *>(workspace), stream); stream);
} }
void nvte_cast_transpose_dbias_drelu(const NVTETensor input, const NVTETensor relu_input, void nvte_cast_transpose_dbias_drelu(const NVTETensor input, const NVTETensor relu_input,
...@@ -1318,9 +1318,9 @@ void nvte_cast_transpose_dbias_drelu(const NVTETensor input, const NVTETensor re ...@@ -1318,9 +1318,9 @@ void nvte_cast_transpose_dbias_drelu(const NVTETensor input, const NVTETensor re
constexpr bool IS_ACT = false; constexpr bool IS_ACT = false;
cast_transpose_fused<IS_DBIAS, IS_DACT, IS_ACT, ComputeType, Empty, drelu<fp32, fp32>>( cast_transpose_fused<IS_DBIAS, IS_DACT, IS_ACT, ComputeType, Empty, drelu<fp32, fp32>>(
*reinterpret_cast<const Tensor *>(input), reinterpret_cast<const Tensor *>(relu_input), *convertNVTETensorCheck(input), convertNVTETensorCheck(relu_input),
reinterpret_cast<Tensor *>(output), reinterpret_cast<Tensor *>(dbias), convertNVTETensorCheck(output), convertNVTETensorCheck(dbias), convertNVTETensor(workspace),
reinterpret_cast<Tensor *>(workspace), stream); stream);
} }
void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input, const NVTETensor srelu_input, void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input, const NVTETensor srelu_input,
...@@ -1335,9 +1335,9 @@ void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input, const NVTETensor s ...@@ -1335,9 +1335,9 @@ void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input, const NVTETensor s
constexpr bool IS_ACT = false; constexpr bool IS_ACT = false;
cast_transpose_fused<IS_DBIAS, IS_DACT, IS_ACT, ComputeType, Empty, dsrelu<fp32, fp32>>( cast_transpose_fused<IS_DBIAS, IS_DACT, IS_ACT, ComputeType, Empty, dsrelu<fp32, fp32>>(
*reinterpret_cast<const Tensor *>(input), reinterpret_cast<const Tensor *>(srelu_input), *convertNVTETensorCheck(input), convertNVTETensorCheck(srelu_input),
reinterpret_cast<Tensor *>(output), reinterpret_cast<Tensor *>(dbias), convertNVTETensorCheck(output), convertNVTETensorCheck(dbias), convertNVTETensor(workspace),
reinterpret_cast<Tensor *>(workspace), stream); stream);
} }
void nvte_cast_transpose_dbias_dqgelu(const NVTETensor input, const NVTETensor qgelu_input, void nvte_cast_transpose_dbias_dqgelu(const NVTETensor input, const NVTETensor qgelu_input,
...@@ -1352,9 +1352,9 @@ void nvte_cast_transpose_dbias_dqgelu(const NVTETensor input, const NVTETensor q ...@@ -1352,9 +1352,9 @@ void nvte_cast_transpose_dbias_dqgelu(const NVTETensor input, const NVTETensor q
constexpr bool IS_ACT = false; constexpr bool IS_ACT = false;
cast_transpose_fused<IS_DBIAS, IS_DACT, IS_ACT, ComputeType, Empty, dqgelu<fp32, fp32>>( cast_transpose_fused<IS_DBIAS, IS_DACT, IS_ACT, ComputeType, Empty, dqgelu<fp32, fp32>>(
*reinterpret_cast<const Tensor *>(input), reinterpret_cast<const Tensor *>(qgelu_input), *convertNVTETensorCheck(input), convertNVTETensorCheck(qgelu_input),
reinterpret_cast<Tensor *>(output), reinterpret_cast<Tensor *>(dbias), convertNVTETensorCheck(output), convertNVTETensorCheck(dbias), convertNVTETensor(workspace),
reinterpret_cast<Tensor *>(workspace), stream); stream);
} }
void nvte_dgeglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input, void nvte_dgeglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input,
...@@ -1364,8 +1364,8 @@ void nvte_dgeglu_cast_transpose(const NVTETensor input, const NVTETensor gated_a ...@@ -1364,8 +1364,8 @@ void nvte_dgeglu_cast_transpose(const NVTETensor input, const NVTETensor gated_a
using namespace transformer_engine::detail; using namespace transformer_engine::detail;
dgated_act_cast_transpose<ComputeType, Empty, dgelu<fp32, fp32>, gelu<fp32, fp32>>( dgated_act_cast_transpose<ComputeType, Empty, dgelu<fp32, fp32>, gelu<fp32, fp32>>(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(gated_act_input), *convertNVTETensorCheck(input), *convertNVTETensorCheck(gated_act_input),
reinterpret_cast<Tensor *>(output), stream); convertNVTETensorCheck(output), stream);
} }
void nvte_dswiglu_cast_transpose(const NVTETensor input, const NVTETensor swiglu_input, void nvte_dswiglu_cast_transpose(const NVTETensor input, const NVTETensor swiglu_input,
...@@ -1375,8 +1375,8 @@ void nvte_dswiglu_cast_transpose(const NVTETensor input, const NVTETensor swiglu ...@@ -1375,8 +1375,8 @@ void nvte_dswiglu_cast_transpose(const NVTETensor input, const NVTETensor swiglu
using namespace transformer_engine::detail; using namespace transformer_engine::detail;
dgated_act_cast_transpose<ComputeType, Empty, dsilu<fp32, fp32>, silu<fp32, fp32>>( dgated_act_cast_transpose<ComputeType, Empty, dsilu<fp32, fp32>, silu<fp32, fp32>>(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(swiglu_input), *convertNVTETensorCheck(input), *convertNVTETensorCheck(swiglu_input),
reinterpret_cast<Tensor *>(output), stream); convertNVTETensorCheck(output), stream);
} }
void nvte_dreglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input, void nvte_dreglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input,
...@@ -1386,8 +1386,8 @@ void nvte_dreglu_cast_transpose(const NVTETensor input, const NVTETensor gated_a ...@@ -1386,8 +1386,8 @@ void nvte_dreglu_cast_transpose(const NVTETensor input, const NVTETensor gated_a
using namespace transformer_engine::detail; using namespace transformer_engine::detail;
dgated_act_cast_transpose<ComputeType, Empty, drelu<fp32, fp32>, relu<fp32, fp32>>( dgated_act_cast_transpose<ComputeType, Empty, drelu<fp32, fp32>, relu<fp32, fp32>>(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(gated_act_input), *convertNVTETensorCheck(input), *convertNVTETensorCheck(gated_act_input),
reinterpret_cast<Tensor *>(output), stream); convertNVTETensorCheck(output), stream);
} }
void nvte_dsreglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input, void nvte_dsreglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input,
...@@ -1397,8 +1397,8 @@ void nvte_dsreglu_cast_transpose(const NVTETensor input, const NVTETensor gated_ ...@@ -1397,8 +1397,8 @@ void nvte_dsreglu_cast_transpose(const NVTETensor input, const NVTETensor gated_
using namespace transformer_engine::detail; using namespace transformer_engine::detail;
dgated_act_cast_transpose<ComputeType, Empty, dsrelu<fp32, fp32>, srelu<fp32, fp32>>( dgated_act_cast_transpose<ComputeType, Empty, dsrelu<fp32, fp32>, srelu<fp32, fp32>>(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(gated_act_input), *convertNVTETensorCheck(input), *convertNVTETensorCheck(gated_act_input),
reinterpret_cast<Tensor *>(output), stream); convertNVTETensorCheck(output), stream);
} }
void nvte_dqgeglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input, void nvte_dqgeglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input,
...@@ -1408,6 +1408,6 @@ void nvte_dqgeglu_cast_transpose(const NVTETensor input, const NVTETensor gated_ ...@@ -1408,6 +1408,6 @@ void nvte_dqgeglu_cast_transpose(const NVTETensor input, const NVTETensor gated_
using namespace transformer_engine::detail; using namespace transformer_engine::detail;
dgated_act_cast_transpose<ComputeType, Empty, dqgelu<fp32, fp32>, qgelu<fp32, fp32>>( dgated_act_cast_transpose<ComputeType, Empty, dqgelu<fp32, fp32>, qgelu<fp32, fp32>>(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(gated_act_input), *convertNVTETensorCheck(input), *convertNVTETensorCheck(gated_act_input),
reinterpret_cast<Tensor *>(output), stream); convertNVTETensorCheck(output), stream);
} }
...@@ -334,8 +334,8 @@ void nvte_multi_cast_transpose(size_t num_tensors, const NVTETensor* input_list, ...@@ -334,8 +334,8 @@ void nvte_multi_cast_transpose(size_t num_tensors, const NVTETensor* input_list,
using namespace transformer_engine; using namespace transformer_engine;
std::vector<Tensor*> input_list_, output_list_; std::vector<Tensor*> input_list_, output_list_;
for (size_t i = 0; i < num_tensors; ++i) { for (size_t i = 0; i < num_tensors; ++i) {
input_list_.push_back(reinterpret_cast<Tensor*>(const_cast<NVTETensor&>(input_list[i]))); input_list_.push_back(convertNVTETensorCheck(input_list[i]));
output_list_.push_back(reinterpret_cast<Tensor*>(output_list[i])); output_list_.push_back(convertNVTETensorCheck(output_list[i]));
} }
multi_cast_transpose(input_list_, output_list_, stream); multi_cast_transpose(input_list_, output_list_, stream);
} }
...@@ -288,14 +288,13 @@ void nvte_transpose(const NVTETensor input, NVTETensor output, cudaStream_t stre ...@@ -288,14 +288,13 @@ void nvte_transpose(const NVTETensor input, NVTETensor output, cudaStream_t stre
NVTE_API_CALL(nvte_transpose); NVTE_API_CALL(nvte_transpose);
using namespace transformer_engine; using namespace transformer_engine;
auto noop = Tensor(); auto noop = Tensor();
transpose(*reinterpret_cast<const Tensor *>(input), noop, reinterpret_cast<Tensor *>(output), transpose(*convertNVTETensorCheck(input), noop, convertNVTETensor(output), stream);
stream);
} }
void nvte_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output, void nvte_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_transpose_with_noop); NVTE_API_CALL(nvte_transpose_with_noop);
using namespace transformer_engine; using namespace transformer_engine;
transpose(*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(noop), transpose(*convertNVTETensorCheck(input), *convertNVTETensorCheck(noop),
reinterpret_cast<Tensor *>(output), stream); convertNVTETensor(output), stream);
} }
...@@ -495,7 +495,6 @@ void nvte_fp8_transpose_dbias(const NVTETensor input, NVTETensor transposed_outp ...@@ -495,7 +495,6 @@ void nvte_fp8_transpose_dbias(const NVTETensor input, NVTETensor transposed_outp
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_fp8_transpose_dbias); NVTE_API_CALL(nvte_fp8_transpose_dbias);
using namespace transformer_engine; using namespace transformer_engine;
fp8_transpose_dbias( fp8_transpose_dbias(*convertNVTETensorCheck(input), convertNVTETensor(transposed_output),
*reinterpret_cast<const Tensor *>(input), reinterpret_cast<Tensor *>(transposed_output), convertNVTETensor(dbias), convertNVTETensor(workspace), stream);
reinterpret_cast<Tensor *>(dbias), reinterpret_cast<Tensor *>(workspace), stream);
} }
...@@ -154,6 +154,5 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activat ...@@ -154,6 +154,5 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activat
void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_dequantize); NVTE_API_CALL(nvte_dequantize);
using namespace transformer_engine; using namespace transformer_engine;
detail::dequantize_helper(*reinterpret_cast<const Tensor *>(input), detail::dequantize_helper(*convertNVTETensorCheck(input), convertNVTETensorCheck(output), stream);
reinterpret_cast<Tensor *>(output), stream);
} }
...@@ -1057,10 +1057,9 @@ void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, ...@@ -1057,10 +1057,9 @@ void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input,
cudaStream_t stream) { cudaStream_t stream) {
using namespace gated_kernels; using namespace gated_kernels;
Tensor grad_empty_tensor; Tensor grad_empty_tensor;
const Tensor &grad_tensor = const Tensor &grad_tensor = IS_DGATED ? *(convertNVTETensorCheck(grad)) : grad_empty_tensor;
IS_DGATED ? *(reinterpret_cast<const Tensor *>(grad)) : grad_empty_tensor; const Tensor gated_input_tensor = *convertNVTETensorCheck(gated_input);
const Tensor gated_input_tensor = *reinterpret_cast<const Tensor *>(gated_input); Tensor *output_tensor = convertNVTETensorCheck(output);
Tensor *output_tensor = reinterpret_cast<Tensor *>(output);
if (is_supported_by_CC_100()) { if (is_supported_by_CC_100()) {
quantize_gated<IS_DGATED, ParamOP, ActOP, DActOP>(grad_tensor, gated_input_tensor, quantize_gated<IS_DGATED, ParamOP, ActOP, DActOP>(grad_tensor, gated_input_tensor,
......
...@@ -1222,23 +1222,23 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o ...@@ -1222,23 +1222,23 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
const Tensor *activation_input_tensor; const Tensor *activation_input_tensor;
if constexpr (IS_DBIAS || IS_DACT) { if constexpr (IS_DBIAS || IS_DACT) {
// backward - input is incoming gradient // backward - input is incoming gradient
input_tensor = reinterpret_cast<const Tensor *>(grad); input_tensor = convertNVTETensorCheck(grad);
activation_input_tensor = reinterpret_cast<const Tensor *>(input); activation_input_tensor = convertNVTETensor(input);
} else { } else {
// forward = input is activation input // forward = input is activation input
input_tensor = reinterpret_cast<const Tensor *>(input); input_tensor = convertNVTETensorCheck(input);
activation_input_tensor = nullptr; activation_input_tensor = nullptr;
} }
auto output_tensor = reinterpret_cast<Tensor *>(output); auto output_tensor = convertNVTETensorCheck(output);
auto dbias_tensor = reinterpret_cast<Tensor *>(dbias); auto dbias_tensor = convertNVTETensor(dbias);
auto workspace_tensor = reinterpret_cast<Tensor *>(workspace); auto workspace_tensor = convertNVTETensor(workspace);
const QuantizationConfig *quant_config_cpp = const QuantizationConfig *quant_config_cpp =
reinterpret_cast<const QuantizationConfig *>(quant_config); reinterpret_cast<const QuantizationConfig *>(quant_config);
// extract noop tensor from quant_config_cpp if it's not null // extract noop tensor from quant_config_cpp if it's not null
const NVTETensor noop = quant_config_cpp ? quant_config_cpp->noop_tensor : nullptr; const NVTETensor noop = quant_config_cpp ? quant_config_cpp->noop_tensor : nullptr;
const auto noop_tensor = noop != nullptr ? *(reinterpret_cast<const Tensor *>(noop)) : Tensor(); const auto noop_tensor = noop != nullptr ? *(convertNVTETensorCheck(noop)) : Tensor();
switch (output_tensor->scaling_mode) { switch (output_tensor->scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: { case NVTE_DELAYED_TENSOR_SCALING: {
......
...@@ -211,8 +211,8 @@ void nvte_multi_padding(size_t num_tensors, const NVTETensor* input_list, NVTETe ...@@ -211,8 +211,8 @@ void nvte_multi_padding(size_t num_tensors, const NVTETensor* input_list, NVTETe
std::vector<Tensor*> input_list_, output_list_; std::vector<Tensor*> input_list_, output_list_;
std::vector<int> padded_num_rows_list_; std::vector<int> padded_num_rows_list_;
for (size_t i = 0; i < num_tensors; ++i) { for (size_t i = 0; i < num_tensors; ++i) {
input_list_.push_back(reinterpret_cast<Tensor*>(const_cast<NVTETensor&>(input_list[i]))); input_list_.push_back(convertNVTETensorCheck(input_list[i]));
output_list_.push_back(reinterpret_cast<Tensor*>(output_list[i])); output_list_.push_back(convertNVTETensorCheck(output_list[i]));
padded_num_rows_list_.push_back(padded_num_rows_list[i]); padded_num_rows_list_.push_back(padded_num_rows_list[i]);
} }
multi_padding(input_list_, output_list_, padded_num_rows_list_, stream); multi_padding(input_list_, output_list_, padded_num_rows_list_, stream);
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include "extensions.h" #include "extensions.h"
#include "transformer_engine/fused_attn.h" #include "transformer_engine/fused_attn.h"
#include "transformer_engine/transformer_engine.h"
namespace transformer_engine { namespace transformer_engine {
namespace jax { namespace jax {
...@@ -40,33 +41,46 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t ...@@ -40,33 +41,46 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t
// all backends need softmax but expect different shapes/dtypes // all backends need softmax but expect different shapes/dtypes
// start with the max512 sequence length softmax shape/dtype and correct later // start with the max512 sequence length softmax shape/dtype and correct later
tensor_pack->size = 1; tensor_pack->size = 1;
Tensor *softmax_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[0]); NVTETensor &softmax_aux = tensor_pack->tensors[0];
softmax_aux->data.dptr = softmax_buf; NVTEBasicTensor softmax_aux_data;
softmax_aux->data.shape = softmax_aux_data.data_ptr = softmax_buf;
std::vector<size_t>{input_batch, attn_heads, q_max_seqlen, kv_max_seqlen}; softmax_aux_data.shape.ndim = 4;
softmax_aux->data.dtype = dtype; softmax_aux_data.shape.data[0] = input_batch;
softmax_aux_data.shape.data[1] = attn_heads;
softmax_aux_data.shape.data[2] = q_max_seqlen;
softmax_aux_data.shape.data[3] = kv_max_seqlen;
softmax_aux_data.dtype = static_cast<NVTEDType>(dtype);
// arbitrary sequence length backend needs the RNG state and a different shape/dtype softmax // arbitrary sequence length backend needs the RNG state and a different shape/dtype softmax
if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
tensor_pack->size = 2; tensor_pack->size = 2;
Tensor *rng_state_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[1]); NVTETensor &rng_state_aux = tensor_pack->tensors[1];
rng_state_aux->data.dptr = rng_state_buf; NVTEBasicTensor rng_state_aux_data;
rng_state_aux->data.shape = std::vector<size_t>{2}; rng_state_aux_data.data_ptr = rng_state_buf;
rng_state_aux->data.dtype = DType::kInt64; rng_state_aux_data.shape = {};
rng_state_aux_data.shape.ndim = 2;
rng_state_aux_data.dtype = static_cast<NVTEDType>(DType::kInt64);
nvte_set_tensor_param(&rng_state_aux, kNVTERowwiseData, &rng_state_aux_data);
// correct softmax shape/dtype // correct softmax shape/dtype
softmax_aux->data.shape.at(3) = 1; // {B,H,Qs,Ks} -> {B,H,Qs,1} softmax_aux_data.shape.data[3] = 1; // {B,H,Qs,Ks} -> {B,H,Qs,1}
softmax_aux->data.dtype = DType::kFloat32; softmax_aux_data.dtype = static_cast<NVTEDType>(DType::kFloat32);
// include bias if enabled // include bias if enabled
if (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS && bias_type != NVTE_Bias_Type::NVTE_ALIBI) { if (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS && bias_type != NVTE_Bias_Type::NVTE_ALIBI) {
tensor_pack->size = 3; tensor_pack->size = 3;
Tensor *bias_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[2]); NVTETensor &bias_aux = tensor_pack->tensors[2];
bias_aux->data.dptr = bias_buf; NVTEBasicTensor bias_aux_data;
bias_aux->data.shape = bias_aux_data.data_ptr = bias_buf;
std::vector<size_t>{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; bias_aux_data.shape.ndim = 4;
bias_aux->data.dtype = dtype; bias_aux_data.shape.data[0] = bias_batch;
bias_aux_data.shape.data[1] = bias_heads;
bias_aux_data.shape.data[2] = q_max_seqlen;
bias_aux_data.shape.data[3] = kv_max_seqlen;
bias_aux_data.dtype = static_cast<NVTEDType>(dtype);
nvte_set_tensor_param(&bias_aux, kNVTERowwiseData, &bias_aux_data);
} }
} }
nvte_set_tensor_param(&softmax_aux, kNVTERowwiseData, &softmax_aux_data);
} }
/* /*
...@@ -93,9 +107,11 @@ void PrepareFusedAttnBackwardAuxTensors(NVTETensorPack *tensor_pack, const size_ ...@@ -93,9 +107,11 @@ void PrepareFusedAttnBackwardAuxTensors(NVTETensorPack *tensor_pack, const size_
// correct softmax shape for max512 sequence length kernel // correct softmax shape for max512 sequence length kernel
if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
Tensor *softmax_aux = reinterpret_cast<Tensor *>(tensor_pack->tensors[0]); NVTEBasicTensor softmax_aux_data =
softmax_aux->data.shape.at(3) = kv_max_seqlen; // {B,H,Qs,1} -> {B,H,Qs,Ks} nvte_get_tensor_param(tensor_pack->tensors[0], kNVTERowwiseData);
softmax_aux->data.dtype = dtype; softmax_aux_data.shape.data[3] = kv_max_seqlen; // {B,H,Qs,1} -> {B,H,Qs,Ks}
softmax_aux_data.dtype = static_cast<NVTEDType>(dtype);
nvte_set_tensor_param(&(tensor_pack->tensors[0]), kNVTERowwiseData, &softmax_aux_data);
} }
} }
......
...@@ -20,6 +20,20 @@ std::vector<size_t> getTensorShape(at::Tensor t) { ...@@ -20,6 +20,20 @@ std::vector<size_t> getTensorShape(at::Tensor t) {
return shape; return shape;
} }
/*! \brief Convert a Torch shape (c10::IntArrayRef) into an NVTEShape.
 *
 *  Copies each dimension of \p torch_shape into the fixed-capacity
 *  NVTEShape::data array, casting from Torch's signed int64 extents
 *  to size_t.
 *
 *  \param torch_shape  Shape reported by an at::Tensor (sizes()).
 *  \return             Equivalent NVTEShape.
 *
 *  Throws (via NVTE_CHECK) if the tensor has more dimensions than
 *  NVTEShape can hold.
 */
NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape) {
  NVTEShape ret;
  ret.ndim = torch_shape.size();
  // Capacity of the fixed-size dimension array inside NVTEShape.
  // Use size_t to avoid a signed/unsigned comparison with ret.ndim.
  constexpr size_t max_dimensions = sizeof(ret.data) / sizeof(ret.data[0]);
  // A shape with exactly max_dimensions entries still fits, so the bound
  // must be inclusive (the previous strict '<' rejected a valid
  // max-rank tensor, contradicting the "Max supported" message).
  NVTE_CHECK(ret.ndim <= max_dimensions,
             "Torch tensor has too many dimensions. Max supported: ", max_dimensions,
             " and got ", ret.ndim, ".");
  for (size_t i = 0; i < ret.ndim; ++i) {
    // Torch extents are int64_t; narrow explicitly to size_t.
    ret.data[i] = static_cast<size_t>(torch_shape[i]);
  }
  return ret;
}
std::unique_ptr<Quantizer> convert_quantizer(py::handle quantizer) { std::unique_ptr<Quantizer> convert_quantizer(py::handle quantizer) {
init_extension(); init_extension();
if (quantizer.is_none()) { if (quantizer.is_none()) {
......
...@@ -351,6 +351,7 @@ std::vector<size_t> convertShape(const NVTEShape& shape); ...@@ -351,6 +351,7 @@ std::vector<size_t> convertShape(const NVTEShape& shape);
int roundup(const int value, const int multiple); int roundup(const int value, const int multiple);
NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape);
} // namespace transformer_engine::pytorch } // namespace transformer_engine::pytorch
namespace std { namespace std {
......
...@@ -9,8 +9,8 @@ ...@@ -9,8 +9,8 @@
#include <string> #include <string>
#include "common/common.h"
#include "extensions.h" #include "extensions.h"
#include "transformer_engine/transformer_engine.h"
namespace transformer_engine::pytorch { namespace transformer_engine::pytorch {
...@@ -34,30 +34,35 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor& amax_reductio ...@@ -34,30 +34,35 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor& amax_reductio
const std::string& amax_compute_algo, const std::string& amax_compute_algo,
DType fp8_dtype, float margin) { DType fp8_dtype, float margin) {
size_t num_tensors = amax_histories.size(); size_t num_tensors = amax_histories.size();
std::vector<Tensor> t_amax_histories(num_tensors); std::vector<NVTETensor> te_amax_histories;
std::vector<Tensor> t_scales(num_tensors); std::vector<NVTETensor> te_scales;
std::vector<NVTETensor> te_amax_histories(num_tensors); te_amax_histories.reserve(num_tensors);
std::vector<NVTETensor> te_scales(num_tensors); te_scales.reserve(num_tensors);
for (size_t i = 0; i < num_tensors; i++) { for (size_t i = 0; i < num_tensors; i++) {
t_amax_histories[i].data.dptr = amax_histories[i].data_ptr(); te_amax_histories.push_back(nvte_create_tensor(NVTE_DELAYED_TENSOR_SCALING));
auto amax_sizes = amax_histories[i].sizes().vec(); NVTETensor& amax_history = te_amax_histories.back();
std::vector<size_t> amax_shape{amax_sizes.begin(), amax_sizes.end()}; NVTEShape amax_shape = convertTorchShape(amax_histories[i].sizes());
t_amax_histories[i].data.shape = amax_shape; NVTEBasicTensor amax_history_data = {amax_histories[i].data_ptr(),
t_amax_histories[i].data.dtype = DType::kFloat32; static_cast<NVTEDType>(DType::kFloat32), amax_shape};
nvte_set_tensor_param(&amax_history, kNVTERowwiseData, &amax_history_data);
t_scales[i].data.dptr = scales[i].data_ptr();
auto scale_sizes = scales[i].sizes().vec(); te_scales.push_back(nvte_create_tensor(NVTE_DELAYED_TENSOR_SCALING));
std::vector<size_t> scale_shape{scale_sizes.begin(), scale_sizes.end()}; NVTETensor& scale = te_scales.back();
t_scales[i].data.shape = scale_shape; NVTEShape scale_shape = convertTorchShape(scales[i].sizes());
t_scales[i].data.dtype = DType::kFloat32; NVTEBasicTensor scale_data = {scales[i].data_ptr(), static_cast<NVTEDType>(DType::kFloat32),
scale_shape};
te_amax_histories[i] = reinterpret_cast<NVTETensor>(&t_amax_histories[i]); nvte_set_tensor_param(&scale, kNVTERowwiseData, &scale_data);
te_scales[i] = reinterpret_cast<NVTETensor>(&t_scales[i]);
} }
nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
makeTransformerEngineTensor(amax_reduction_buffer).data(), te_amax_histories, te_scales, makeTransformerEngineTensor(amax_reduction_buffer).data(), te_amax_histories, te_scales,
amax_compute_algo.c_str(), static_cast<NVTEDType>(fp8_dtype), margin, amax_compute_algo.c_str(), static_cast<NVTEDType>(fp8_dtype), margin,
at::cuda::getCurrentCUDAStream()); at::cuda::getCurrentCUDAStream());
for (auto& t : te_amax_histories) {
nvte_destroy_tensor(t);
}
for (auto& t : te_scales) {
nvte_destroy_tensor(t);
}
} }
} // namespace transformer_engine::pytorch } // namespace transformer_engine::pytorch
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment