Unverified Commit 14b53313 authored by Phuong Nguyen, committed by GitHub

[Common] NVTEGroupedTensor class and helpers (#2388)

* add grouped_tensor classes and helpers
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* rm non-contiguous option and dptrs
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* address comments + rework CheckIn/OutputGroupedTensor
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix for compilation
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* make first_dims/last_dims optional + data.shape 2d
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* added assertion
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* rs conflicts
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* add data.shape info
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* added logical shape field
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* compilation fix
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* fixed issues raised by greptile
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* return default dtype when grouped_tensor is empty
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use has_data() for dim queries
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update comments
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* fix index bound
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* Update transformer_engine/common/transformer_engine.cpp
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* Update transformer_engine/common/transformer_engine.cpp
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* restore Tensor.has_data() + add experimental marks
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* restore Tensor::has_columnwise_data
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* cleanup
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
parent f1512b21
@@ -101,6 +101,7 @@ struct SimpleTensor {
}
return acc;
}
bool has_data() const noexcept { return dptr != nullptr && numel() > 0; }
void clear() {
dptr = nullptr;
@@ -154,9 +155,11 @@ struct Tensor {
return acc;
}
// TODO(Tim): Change this to use data.has_data()
bool has_data() const noexcept { return data.dptr != nullptr; }
// Check for size (not just pointer) for 0-dim or no token cases.
// TODO(Tim): Change this to use columnwise_data.has_data()
bool has_columnwise_data() const noexcept {
return columnwise_data.dptr != nullptr || columnwise_data.shape.size() != 0;
}
@@ -281,6 +284,129 @@ struct Tensor {
}
};
struct GroupedTensor {
public:
/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
/*
A grouped tensor is a collection of tensors with different shapes but the same dtype and scaling mode.
Shape Representation:
- logical_shape: 2D shape representing the conceptual layout, i.e. the shape when member tensors are flattened to 2D and stacked together (REQUIRED)
+ When all_same_shape(): [num_tensors * M, N] where each tensor is (M, N)
+ When varying_first_dim(): [~sum_of_first_dims, N] where N is common
+ When varying_last_dim(): [M, ~sum_of_last_dims] where M is common
+ When varying_both_dims(): [1, total_elements] (fully flattened)
- first_dims and last_dims are OPTIONAL (empty if dimension is uniform)
+ Empty first_dims: all tensors have the same first dimension
+ Empty last_dims: all tensors have the same last dimension
+ Both empty: all tensors have identical shapes
+ Both set: each tensor has unique shape (first_dims[i], last_dims[i])
Data Layout:
- ALL data fields are stored as 1D flattened arrays (data, columnwise_data, scale_inv, etc.)
- logical_shape provides the conceptual 2D interpretation
- All data is stored on device in contiguous layout
*/
SimpleTensor data;
SimpleTensor columnwise_data;
SimpleTensor scale_inv;
SimpleTensor columnwise_scale_inv;
SimpleTensor amax;
SimpleTensor columnwise_amax;
SimpleTensor scale; // for FP8 delayed scaling (FP8-DS) only
// Shape information (OPTIONAL - empty if dimension is uniform across all tensors)
// first_dims[i] = first dimension of tensor i (empty if all tensors have same first dim)
// last_dims[i] = last dimension of tensor i (empty if all tensors have same last dim)
SimpleTensor first_dims; // Device pointer to int64_t array of length num_tensors (or empty)
SimpleTensor last_dims; // Device pointer to int64_t array of length num_tensors (or empty)
// Offsets for indexing into contiguous 1D layout (OPTIONAL - not needed if all_same_shape())
// tensor_offsets[i] = element offset to start of tensor i (cumulative sum of numel for tensors 0..i-1)
// Usage: tensor_i_ptr = (char*)data.dptr + tensor_offsets[i] * element_size
// If empty and all_same_shape(): offset[i] = i * M * N (where M, N are common dimensions)
SimpleTensor tensor_offsets; // Device pointer to int64_t array of length num_tensors (or empty)
// Logical shape: conceptual 2D shape of the grouped data (REQUIRED)
// Represents how the 1D flattened data should be interpreted as 2D
// Always 2D with positive dimensions
NVTEShape logical_shape;
NVTEScalingMode scaling_mode;
size_t num_tensors;
NVTEGroupedTensor nvte_tensor;
GroupedTensor(NVTEScalingMode scaling_mode, size_t num_tensors)
: data(),
columnwise_data(),
scale_inv(),
columnwise_scale_inv(),
amax(),
columnwise_amax(),
scale(),
first_dims(nullptr, {}, DType::kInt64),
last_dims(nullptr, {}, DType::kInt64),
tensor_offsets(nullptr, {}, DType::kInt64),
logical_shape(nvte_make_shape(nullptr, 0)),
scaling_mode(scaling_mode),
num_tensors(num_tensors),
nvte_tensor(0) {}
explicit operator NVTEGroupedTensor() const noexcept { return nvte_tensor; }
bool has_data() const noexcept { return data.has_data(); }
bool has_columnwise_data() const noexcept { return columnwise_data.has_data(); }
bool all_same_first_dim() const noexcept { return !first_dims.has_data(); }
bool all_same_last_dim() const noexcept { return !last_dims.has_data(); }
bool all_same_shape() const noexcept { return !first_dims.has_data() && !last_dims.has_data(); }
bool varying_both_dims() const noexcept { return first_dims.has_data() && last_dims.has_data(); }
size_t get_common_first_dim() const {
NVTE_CHECK(all_same_first_dim(), "First dim varies across tensors");
NVTE_CHECK(logical_shape.ndim == 2, "Logical shape must be 2D");
if (all_same_shape()) {
// When both dims are uniform: logical_shape = [num_tensors * M, N]
return logical_shape.data[0] / num_tensors;
} else {
// When varying last dims but not first dim: logical_shape = [M, sum_of_last_dims]
return logical_shape.data[0];
}
}
size_t get_common_last_dim() const {
NVTE_CHECK(all_same_last_dim(), "Last dim varies across tensors");
NVTE_CHECK(logical_shape.ndim == 2, "Logical shape must be 2D");
// For both uniform and varying first dim cases: logical_shape[1] is the common last dim
return logical_shape.data[1];
}
DType dtype() const {
if (has_data()) return data.dtype;
if (has_columnwise_data()) return columnwise_data.dtype;
// Fallback, used e.g. in workspace or when allow_empty=true
return data.dtype;
}
void clear() {
data.clear();
columnwise_data.clear();
scale_inv.clear();
columnwise_scale_inv.clear();
amax.clear();
columnwise_amax.clear();
scale.clear();
first_dims.clear();
last_dims.clear();
tensor_offsets.clear();
logical_shape = nvte_make_shape(nullptr, 0);
num_tensors = 0;
scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
nvte_tensor = 0;
}
};
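To make the offset convention above concrete, here is a hedged host-side sketch; member_pointers is a hypothetical helper (not part of this change) and assumes tensor_offsets, if present, has already been copied to host memory as h_offsets:

#include <cstdint>
#include <vector>

std::vector<void *> member_pointers(const GroupedTensor &g, const int64_t *h_offsets,
                                    size_t element_size) {
  std::vector<void *> ptrs(g.num_tensors);
  // all_same_shape(): logical_shape = [num_tensors * M, N], so each member
  // occupies M * N elements and offset[i] = i * M * N (offsets array may be empty).
  const size_t uniform_numel =
      (g.logical_shape.data[0] / g.num_tensors) * g.logical_shape.data[1];
  for (size_t i = 0; i < g.num_tensors; ++i) {
    const int64_t offset =
        h_offsets ? h_offsets[i] : static_cast<int64_t>(i * uniform_numel);
    // Documented convention: tensor_i_ptr = (char*)data.dptr + tensor_offsets[i] * element_size
    ptrs[i] = static_cast<char *>(g.data.dptr) + offset * element_size;
  }
  return ptrs;
}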
struct QuantizationConfig {
bool force_pow_2_scales = false;
float amax_epsilon = 0.0f;
@@ -779,6 +905,16 @@ std::vector<std::vector<Tensor *>> convert_tensor_array(NVTETensor **nvte_tensor
Tensor *convertNVTETensor(const NVTETensor tensor);
Tensor *convertNVTETensorCheck(const NVTETensor tensor);
GroupedTensor *convertNVTEGroupedTensor(const NVTEGroupedTensor tensor);
GroupedTensor *convertNVTEGroupedTensorCheck(const NVTEGroupedTensor tensor);
// Helper functions for GroupedTensor validation
void CheckGroupedTensorShapeArrays(const GroupedTensor &t, const std::string &name);
void CheckInputGroupedTensor(const GroupedTensor &t, const std::string &name);
void CheckOutputGroupedTensor(const GroupedTensor &t, const std::string &name,
bool allow_empty = false);
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_COMMON_H_
@@ -393,6 +393,114 @@ int nvte_is_non_tn_fp8_gemm_supported();
*/
void nvte_memset(void *ptr, int value, size_t size_in_bytes, cudaStream_t stream);
/*! \brief TE Grouped Tensor type
*
* NVTEGroupedTensor is a collection of tensors with potentially different shapes
* but the same dtype and scaling mode. It does not own the memory it points to.
*/
typedef void *NVTEGroupedTensor;
/*! \enum NVTEGroupedTensorParam
* \brief Indicates the kind of the grouped tensor parameter to set/get.
*/
enum NVTEGroupedTensorParam {
kNVTEGroupedRowwiseData = 0, /*!< Data usable in rowwise manner */
kNVTEGroupedColumnwiseData = 1, /*!< Data usable in columnwise manner */
kNVTEGroupedScale = 2, /*!< Scale tensor */
kNVTEGroupedAmax = 3, /*!< Amax tensor */
kNVTEGroupedRowwiseScaleInv = 4, /*!< Scale inverse tensor for decoding Rowwise Data */
kNVTEGroupedColumnwiseScaleInv = 5, /*!< Scale inverse tensor for decoding Columnwise Data */
kNVTEGroupedColumnwiseAmax = 6, /*!< Columnwise Amax tensor */
kNVTEGroupedFirstDims = 7, /*!< First dimension sizes (device pointer to int64_t array) */
kNVTEGroupedLastDims = 8, /*!< Last dimension sizes (device pointer to int64_t array) */
kNVTEGroupedTensorOffsets =
9, /*!< Tensor offsets for contiguous layout (device pointer to int64_t array) */
kNVTENumGroupedTensorParams
};
/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
/*! \brief Create a new TE grouped tensor.
*
* Create a new TE grouped tensor. Before use, its parameters need to be set.
* TE grouped tensors are just wrappers on top of raw data and do not
* own memory.
*
* \param[in] scaling_mode Scaling mode of the grouped tensor.
* \param[in] num_tensors Number of tensors in the group (must be > 0).
* \param[in] logical_shape Logical 2D shape of the grouped data.
*
* \return A new TE grouped tensor.
*/
NVTEGroupedTensor nvte_create_grouped_tensor(NVTEScalingMode scaling_mode, size_t num_tensors,
NVTEShape logical_shape);
/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
/*! \brief Destroy a TE grouped tensor.
*
* Since the TE grouped tensor does not own memory, the underlying
* data is not freed during this operation.
*
* \param[in] tensor Grouped tensor to be destroyed.
*/
void nvte_destroy_grouped_tensor(NVTEGroupedTensor tensor);
/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
/*! \brief Set a parameter of the grouped tensor.
*
* \param[in,out] tensor Grouped tensor.
* \param[in] param_name The parameter to be set.
* \param[in] param The value to be set (NVTEBasicTensor).
*/
void nvte_set_grouped_tensor_param(NVTEGroupedTensor *tensor, NVTEGroupedTensorParam param_name,
const NVTEBasicTensor *param);
/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
/*! \brief Get the value of a parameter of the grouped tensor.
*
* \param[in] tensor Grouped tensor.
* \param[in] param_name The parameter to be queried.
*
* \return NVTEBasicTensor containing the parameter data.
*/
NVTEBasicTensor nvte_get_grouped_tensor_param(const NVTEGroupedTensor tensor,
NVTEGroupedTensorParam param_name);
/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
/*! \brief Get the number of tensors in a grouped tensor.
*
* \param[in] tensor Grouped tensor.
*
* \return Number of tensors in the group.
*/
size_t nvte_grouped_tensor_num_tensors(const NVTEGroupedTensor tensor);
/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
/*! \brief Get a grouped tensor's data type.
*
* \param[in] tensor Grouped tensor.
*
* \return The data type of the grouped tensor.
*/
NVTEDType nvte_grouped_tensor_type(const NVTEGroupedTensor tensor);
/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
/*! \brief Get the scaling mode of the grouped tensor.
*
* \param[in] tensor Grouped tensor.
*
* \return Scaling mode of the grouped tensor.
*/
NVTEScalingMode nvte_grouped_tensor_scaling_mode(const NVTEGroupedTensor tensor);
/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
/*! \brief Get the logical shape of a grouped tensor.
*
* \param[in] tensor Grouped tensor.
*
* \return Logical 2D shape.
*/
NVTEShape nvte_get_grouped_tensor_logical_shape(const NVTEGroupedTensor tensor);
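Taken together, a hedged end-to-end sketch of the C API declared above; d_data stands in for a caller-owned device buffer of 32 floats (its allocation is an assumption, not part of the API):

void *d_data;  /* assumed: device buffer holding 32 floats (e.g. from cudaMalloc) */
size_t dims[2] = {8, 4};  /* 2 member tensors of shape (4, 4), stacked rowwise */
NVTEGroupedTensor g = nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING,
                                                 /*num_tensors=*/2,
                                                 nvte_make_shape(dims, 2));

size_t flat = 32;  /* all data fields are passed 1D/flattened */
NVTEBasicTensor data = {d_data, kNVTEFloat32, nvte_make_shape(&flat, 1)};
nvte_set_grouped_tensor_param(&g, kNVTEGroupedRowwiseData, &data);

/* first_dims, last_dims, tensor_offsets stay unset: both dims are uniform here */
NVTEShape ls = nvte_get_grouped_tensor_logical_shape(g);  /* {8, 4} */
size_t n = nvte_grouped_tensor_num_tensors(g);            /* 2 */

nvte_destroy_grouped_tensor(g);  /* releases the handle; caller still owns d_data */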
#ifdef __cplusplus
} // extern "C"
@@ -273,6 +273,128 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
CheckScaleTensorShape(t, name);
}
void CheckGroupedTensorShapeArrays(const GroupedTensor &t, const std::string &name) {
NVTE_CHECK(t.num_tensors > 0, "Grouped tensor ", name, " has no tensors!");
// Helper lambda to validate shape arrays
// All three arrays are OPTIONAL:
// - first_dims: empty if all tensors have same first dimension
// - last_dims: empty if all tensors have same last dimension
// - tensor_offsets: empty if all tensors have same shape (offsets are predictable)
auto check_shape_array = [&](const SimpleTensor &arr, const char *arr_name) {
if (arr.has_data()) {
NVTE_CHECK(arr.shape.size() == 1, "Grouped tensor ", name, " ", arr_name, " must be 1D");
NVTE_CHECK(arr.dtype == DType::kInt64, "Grouped tensor ", name, " ", arr_name,
" must have dtype Int64");
NVTE_CHECK(arr.shape[0] == t.num_tensors, "Grouped tensor ", name, " ", arr_name, " size (",
arr.shape[0], ") must equal num_tensors (", t.num_tensors, ")");
}
};
// Validate shape arrays (all optional)
check_shape_array(t.first_dims, "first_dims");
check_shape_array(t.last_dims, "last_dims");
check_shape_array(t.tensor_offsets, "tensor_offsets");
// tensor_offsets is required if any dimension varies
// (i.e., required unless all_same_shape())
if (!t.all_same_shape()) {
NVTE_CHECK(
t.tensor_offsets.dptr != nullptr, "Grouped tensor ", name,
" must have tensor_offsets when any dimension varies (first_dims or last_dims is set)");
}
// Validate logical_shape
NVTE_CHECK(t.logical_shape.ndim == 2, "Grouped tensor ", name, " logical_shape must be 2D");
NVTE_CHECK(t.logical_shape.data[0] > 0 && t.logical_shape.data[1] > 0, "Grouped tensor ", name,
" logical_shape must have positive dimensions");
// Validate all data fields are 1D (flattened)
if (t.has_data()) {
NVTE_CHECK(t.data.shape.size() == 1, "Grouped tensor ", name, " data must be 1D");
}
if (t.has_columnwise_data()) {
NVTE_CHECK(t.columnwise_data.shape.size() == 1, "Grouped tensor ", name,
" columnwise_data must be 1D");
}
// Validate data size matches logical_shape
size_t expected_numel = t.logical_shape.data[0] * t.logical_shape.data[1];
if (t.has_data()) {
NVTE_CHECK(t.data.numel() == expected_numel, "Grouped tensor ", name, " data size (",
t.data.numel(), ") must match logical_shape size (", expected_numel, ")");
}
if (t.has_columnwise_data()) {
NVTE_CHECK(t.columnwise_data.numel() == expected_numel, "Grouped tensor ", name,
" columnwise_data size (", t.columnwise_data.numel(),
") must match logical_shape size (", expected_numel, ")");
}
}
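As a concrete illustration, a minimal hedged sketch of a grouped tensor that satisfies these checks in the varying-first-dim case (d_data, d_first_dims, and d_offsets are hypothetical device buffers, not test code from this change):

// 3 BF16 tensors of shapes (2, 8), (5, 8), (1, 8): first dims vary, last dim is common.
GroupedTensor g(NVTE_DELAYED_TENSOR_SCALING, /*num_tensors=*/3);
size_t dims[2] = {8, 8};  // 2 + 5 + 1 = 8 rows, 8 columns
g.logical_shape = nvte_make_shape(dims, 2);
g.data = SimpleTensor(d_data, {64}, DType::kBFloat16);          // 1D, 8 * 8 elements
g.first_dims = SimpleTensor(d_first_dims, {3}, DType::kInt64);  // holds {2, 5, 1}
// tensor_offsets is required because a dimension varies: {0, 2*8, (2+5)*8} = {0, 16, 56}
g.tensor_offsets = SimpleTensor(d_offsets, {3}, DType::kInt64);
// last_dims stays empty, so all_same_last_dim() holds and
// get_common_last_dim() == logical_shape.data[1] == 8.
CheckGroupedTensorShapeArrays(g, "example");  // passes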
// Helper function to check scale_inv for both input and output
static void CheckGroupedScaleInv(const GroupedTensor &t, const std::string &name, bool is_output) {
const char *tensor_type = is_output ? "output" : "input";
// Helper to check scale_inv for both rowwise and columnwise layouts
auto check_scales = [&](DType expected_dtype) {
if (t.has_data()) {
NVTE_CHECK(t.scale_inv.has_data(), tensor_type, " ", name,
" rowwise scale_inv must be allocated");
NVTE_CHECK(t.scale_inv.dtype == expected_dtype, tensor_type, " ", name,
" rowwise scale_inv has invalid dtype (expected ", to_string(expected_dtype),
", got ", to_string(t.scale_inv.dtype), ")");
}
if (t.has_columnwise_data()) {
NVTE_CHECK(t.columnwise_scale_inv.has_data(), tensor_type, " ", name,
" columnwise scale_inv must be allocated");
NVTE_CHECK(t.columnwise_scale_inv.dtype == expected_dtype, tensor_type, " ", name,
" columnwise scale_inv has invalid dtype (expected ", to_string(expected_dtype),
", got ", to_string(t.columnwise_scale_inv.dtype), ")");
}
};
// Determine expected dtype based on data type and scaling mode
if (is_fp8_dtype(t.dtype()) && is_tensor_scaling(t.scaling_mode)) {
check_scales(DType::kFloat32);
} else if (is_mxfp8_scaling(t.scaling_mode)) {
check_scales(DType::kFloat8E8M0);
} else if (is_nvfp4_scaling(t.scaling_mode)) {
check_scales(DType::kFloat8E4M3);
} else {
// Non-quantized types should not have scale/scale_inv
NVTE_CHECK(!t.scale_inv.has_data(), "Scale_inv not supported for non-quantized ", tensor_type,
" ", name);
NVTE_CHECK(!t.columnwise_scale_inv.has_data(), "Scale_inv not supported for non-quantized ",
tensor_type, " ", name);
}
}
}
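A hedged sketch of the dtype mapping for the FP8 tensor-scaling branch (d_fp8 and d_sinv are hypothetical device buffers; note this helper validates scale_inv presence and dtype, not its shape convention):

GroupedTensor q(NVTE_DELAYED_TENSOR_SCALING, /*num_tensors=*/2);
size_t qdims[2] = {8, 8};  // 2 member tensors of shape (4, 8)
q.logical_shape = nvte_make_shape(qdims, 2);
q.data = SimpleTensor(d_fp8, {64}, DType::kFloat8E4M3);    // FP8 rowwise data
q.scale_inv = SimpleTensor(d_sinv, {2}, DType::kFloat32);  // tensor scaling expects Float32
// With MXFP8 scaling the expected dtype would be kFloat8E8M0; with NVFP4, kFloat8E4M3.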
void CheckInputGroupedTensor(const GroupedTensor &t, const std::string &name) {
NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Input grouped tensor ", name,
" not allocated");
CheckGroupedScaleInv(t, name, false);
CheckGroupedTensorShapeArrays(t, name);
}
void CheckOutputGroupedTensor(const GroupedTensor &t, const std::string &name, bool allow_empty) {
if (!allow_empty) {
NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Output grouped tensor ", name,
" not allocated");
}
// Only perform dtype-specific validation if data is allocated
if (t.has_data() || t.has_columnwise_data()) {
// Amax validation for delayed scaling
if (is_fp8_dtype(t.dtype()) && t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
NVTE_CHECK(t.amax.has_data(), "Output ", name, " amax must be allocated");
NVTE_CHECK(t.amax.dtype == DType::kFloat32, "Output ", name, " amax must be Float32");
}
CheckGroupedScaleInv(t, name, true);
}
CheckGroupedTensorShapeArrays(t, name);
}
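Continuing the hedged FP8 sketch above: used as an output with delayed scaling, the same grouped tensor additionally needs a Float32 amax before the output check passes (d_amax is another hypothetical device buffer):

q.amax = SimpleTensor(d_amax, {2}, DType::kFloat32);  // required for FP8 delayed-scaling outputs
CheckOutputGroupedTensor(q, "example_output");        // allow_empty defaults to false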
class TensorAllocator {
public:
static TensorAllocator &instance() {
@@ -387,6 +509,89 @@ Tensor *convertNVTETensorCheck(const NVTETensor t) {
return ptr;
}
// GroupedTensor allocator - similar pattern to TensorAllocator
class GroupedTensorAllocator {
public:
static GroupedTensorAllocator &instance() {
static GroupedTensorAllocator allocator;
return allocator;
}
~GroupedTensorAllocator() {}
NVTEGroupedTensor Allocate(NVTEScalingMode mode, size_t num_tensors, NVTEShape logical_shape) {
std::lock_guard<std::mutex> lock(mutex);
if (!free_list.empty()) {
uintptr_t index = free_list.back();
NVTEGroupedTensor ret = reinterpret_cast<NVTEGroupedTensor>(index);
free_list.pop_back();
// 1-based indexing - fully reinitialize the tensor to avoid stale data
memory[index - 1].scaling_mode = mode;
memory[index - 1].num_tensors = num_tensors;
memory[index - 1].logical_shape = logical_shape;
memory[index - 1].nvte_tensor = ret;
return ret;
}
if (memory.size() < memory.capacity()) {
memory.emplace_back(mode, num_tensors);
GroupedTensor &t = memory.back();
size = memory.size();
// 1-based indexing
uintptr_t index = memory.size();
t.logical_shape = logical_shape;
t.nvte_tensor = reinterpret_cast<NVTEGroupedTensor>(index);
return reinterpret_cast<NVTEGroupedTensor>(index);
}
NVTE_ERROR(
"Cannot allocate a new NVTEGroupedTensor. Maximum number of grouped tensors reached: ",
MAX_GROUPED_TENSOR_NUM, ". There is probably a memory leak in your application.");
}
void Free(NVTEGroupedTensor t) {
std::lock_guard<std::mutex> lock(mutex);
uintptr_t index = reinterpret_cast<uintptr_t>(t);
if (index == 0) return;
NVTE_CHECK(index <= memory.size(), "Invalid grouped tensor.");
free_list.push_back(index);
// Clean up
memory[index - 1].clear();
}
GroupedTensor *convertNVTEGroupedTensor(NVTEGroupedTensor t) {
uintptr_t index = reinterpret_cast<uintptr_t>(t);
// 1-based indexing to enable 0-initialization of NVTEGroupedTensor
// to be invalid tensor
static_assert(nullptr == 0);
if (index != 0 && index <= size) {
return &(memory[index - 1]);
}
return nullptr;
}
private:
GroupedTensorAllocator() {
std::lock_guard<std::mutex> lock(mutex);
memory.reserve(MAX_GROUPED_TENSOR_NUM);
}
std::mutex mutex;
std::atomic<size_t> size{0};  // number of constructed slots; read without the lock in convertNVTEGroupedTensor
// Allocate at most 20 MB for grouped tensors
const size_t MAX_GROUPED_TENSOR_NUM = 20 * 1024 * 1024 / sizeof(GroupedTensor);
std::vector<uintptr_t> free_list;
std::vector<GroupedTensor> memory;
};
GroupedTensor *convertNVTEGroupedTensor(const NVTEGroupedTensor t) {
return GroupedTensorAllocator::instance().convertNVTEGroupedTensor(t);
}
GroupedTensor *convertNVTEGroupedTensorCheck(const NVTEGroupedTensor t) {
GroupedTensor *ptr = GroupedTensorAllocator::instance().convertNVTEGroupedTensor(t);
NVTE_CHECK(ptr != nullptr, "Invalid grouped tensor.");
return ptr;
}
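A short hedged illustration (assuming <cassert>) of why the 1-based scheme makes a zero-initialized handle safely invalid:

NVTEGroupedTensor none = 0;  // default/zero handle
assert(transformer_engine::convertNVTEGroupedTensor(none) == nullptr);

size_t dims[2] = {4, 4};
NVTEGroupedTensor h =
    nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING, 1, nvte_make_shape(dims, 2));
assert(reinterpret_cast<uintptr_t>(h) != 0);  // valid handles are (slot index + 1)
nvte_destroy_grouped_tensor(h);               // the slot goes back to free_list for reuse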
} // namespace transformer_engine
NVTETensor nvte_create_tensor(NVTEScalingMode scaling_mode) {
@@ -730,3 +935,132 @@ int nvte_is_non_tn_fp8_gemm_supported() {
});
return cache[device_id];
}
// Grouped Tensor C API implementations
NVTEGroupedTensor nvte_create_grouped_tensor(NVTEScalingMode scaling_mode, size_t num_tensors,
NVTEShape logical_shape) {
NVTE_CHECK(num_tensors > 0, "Number of tensors must be greater than 0");
NVTE_CHECK(logical_shape.ndim == 2, "Logical shape must be 2D");
NVTE_CHECK(logical_shape.data[0] > 0 && logical_shape.data[1] > 0,
"Logical shape must have positive dimensions");
NVTEGroupedTensor ret = transformer_engine::GroupedTensorAllocator::instance().Allocate(
scaling_mode, num_tensors, logical_shape);
return ret;
}
void nvte_destroy_grouped_tensor(NVTEGroupedTensor tensor) {
transformer_engine::GroupedTensorAllocator::instance().Free(tensor);
}
void nvte_set_grouped_tensor_param(NVTEGroupedTensor *tensor, NVTEGroupedTensorParam param_name,
const NVTEBasicTensor *param) {
NVTE_CHECK(tensor != nullptr, "Grouped tensor pointer can't be NULL.");
auto *t = transformer_engine::convertNVTEGroupedTensor(*tensor);
NVTE_CHECK(t != nullptr, "Grouped tensor is not allocated.");
NVTE_CHECK(param != nullptr, "Grouped tensor param can't be NULL.");
switch (param_name) {
case kNVTEGroupedRowwiseData:
t->data = *param;
break;
case kNVTEGroupedColumnwiseData:
t->columnwise_data = *param;
break;
case kNVTEGroupedScale:
t->scale = *param;
break;
case kNVTEGroupedAmax:
t->amax = *param;
break;
case kNVTEGroupedRowwiseScaleInv:
t->scale_inv = *param;
break;
case kNVTEGroupedColumnwiseScaleInv:
t->columnwise_scale_inv = *param;
break;
case kNVTEGroupedColumnwiseAmax:
t->columnwise_amax = *param;
break;
case kNVTEGroupedFirstDims:
t->first_dims = *param;
// Validate it's Int64
NVTE_CHECK(t->first_dims.dtype == transformer_engine::DType::kInt64,
"first_dims must have dtype Int64");
break;
case kNVTEGroupedLastDims:
t->last_dims = *param;
// Validate it's Int64
NVTE_CHECK(t->last_dims.dtype == transformer_engine::DType::kInt64,
"last_dims must have dtype Int64");
break;
case kNVTEGroupedTensorOffsets:
t->tensor_offsets = *param;
// Validate it's Int64
NVTE_CHECK(t->tensor_offsets.dtype == transformer_engine::DType::kInt64,
"tensor_offsets must have dtype Int64");
break;
default:
NVTE_ERROR("Unknown grouped tensor parameter!");
}
}
NVTEBasicTensor nvte_get_grouped_tensor_param(const NVTEGroupedTensor tensor,
NVTEGroupedTensorParam param_name) {
if (tensor == nullptr) {
return {nullptr, kNVTEFloat32, nvte_make_shape(nullptr, 0)};
}
const auto &t = *transformer_engine::convertNVTEGroupedTensorCheck(tensor);
switch (param_name) {
case kNVTEGroupedRowwiseData:
return t.data;
case kNVTEGroupedColumnwiseData:
return t.columnwise_data;
case kNVTEGroupedScale:
return t.scale;
case kNVTEGroupedAmax:
return t.amax;
case kNVTEGroupedRowwiseScaleInv:
return t.scale_inv;
case kNVTEGroupedColumnwiseScaleInv:
return t.columnwise_scale_inv;
case kNVTEGroupedColumnwiseAmax:
return t.columnwise_amax;
case kNVTEGroupedFirstDims:
return t.first_dims;
case kNVTEGroupedLastDims:
return t.last_dims;
case kNVTEGroupedTensorOffsets:
return t.tensor_offsets;
default:
NVTE_ERROR("Unknown grouped tensor parameter!");
}
}
size_t nvte_grouped_tensor_num_tensors(const NVTEGroupedTensor tensor) {
auto *t = transformer_engine::convertNVTEGroupedTensor(tensor);
if (t == nullptr) return 0;
return t->num_tensors;
}
NVTEDType nvte_grouped_tensor_type(const NVTEGroupedTensor tensor) {
auto *t = transformer_engine::convertNVTEGroupedTensor(tensor);
if (t == nullptr) return kNVTEFloat32;
return static_cast<NVTEDType>(t->dtype());
}
NVTEScalingMode nvte_grouped_tensor_scaling_mode(const NVTEGroupedTensor tensor) {
if (tensor == nullptr) {
return NVTE_DELAYED_TENSOR_SCALING;
}
const auto &t = *transformer_engine::convertNVTEGroupedTensorCheck(tensor);
return t.scaling_mode;
}
NVTEShape nvte_get_grouped_tensor_logical_shape(const NVTEGroupedTensor tensor) {
if (tensor == nullptr) {
return nvte_make_shape(nullptr, 0);
}
const auto &t = *transformer_engine::convertNVTEGroupedTensorCheck(tensor);
return t.logical_shape;
}