Unverified commit 61822061, authored by Tim Moon and committed by GitHub
Browse files

[Core] Fix inconsistent logic in C++ tensor class (#2330)



* Initialize empty tensors with shape=[0] instead of shape=[].
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix runtime crash in LayerNorm

Still seeing correctness issues.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Make sure norm workspace sizes are not zero
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove assumption in swizzle kernel that data is available.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove assumption in multi-swizzle kernel that data is available.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Remove unnecessary explicit call to default constructor
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Avoid accessing tensor data pointer if tensor has no entries
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Apply suggestions from code review
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update transformer_engine/common/swizzle/swizzle.cu
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Review suggestions from @ptrendx and @greptile-apps
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Prefer using row-wise/col-wise shape based on which has data
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix merge conflict, expand docs, fix inconsistency in dim function
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Change Tensor::has_data to check whether tensor is initialized, not whether pointer is valid.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Review suggestion from @greptile-apps
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Debug incorrect tensor initialization in tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Clarify comments that has_data does not guarantee safe pointer accesses
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug test failure when computing amaxes
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
parent d126cdd6
...@@ -278,52 +278,33 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape, ...@@ -278,52 +278,33 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
Tensor::Tensor(const std::string& name, Tensor::Tensor(const std::string& name,
const NVTEShape &shape, const DType type, const NVTEShape &shape, const DType type,
const bool rowwise, const bool columnwise, const bool rowwise, const bool columnwise,
const NVTEScalingMode &scaling_mode) { const NVTEScalingMode &scaling_mode)
name_ = name; : tensor_(scaling_mode), rowwise_{rowwise}, columnwise_{columnwise}, name_{name} {
// Initialize RNG
const size_t seed = create_seed_from_tensor_name(name); const size_t seed = create_seed_from_tensor_name(name);
gen_.seed(seed); gen_.seed(seed);
rowwise_ = rowwise;
columnwise_ = columnwise; // Make sure shape is valid
size_t total_size = bytes(shape, type);
void *dptr_rowwise = nullptr;
void *dptr_columnwise = nullptr;
cpu_data_rowwise_ = nullptr;
cpu_data_columnwise_ = nullptr;
amax_cpu_data_ = nullptr;
scale_cpu_data_ = nullptr;
rowwise_scale_inv_cpu_data_ = nullptr;
columnwise_scale_inv_cpu_data_ = nullptr;
float *amax = nullptr, *scale = nullptr;
float *rowwise_scale_inv = nullptr, *columnwise_scale_inv = nullptr;
if (columnwise) { if (columnwise) {
NVTE_CHECK(shape.ndim >= 2); NVTE_CHECK(shape.ndim >= 2);
} }
std::vector<size_t> normalized_shape_v = {product(shape, 0, shape.ndim - 1),
shape.data[shape.ndim - 1]};
NVTEShape normalized_shape = convertShape(normalized_shape_v);
NVTEShape columnwise_shape = {};
std::vector<size_t> columnwise_shape_vec; // Shape after flattening to 2D
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING NVTEShape flattened_shape;
|| scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) { {
// Transpose when tensor scaling std::vector<size_t> flattened_shape_vec;
columnwise_shape_vec.emplace_back(shape.data[shape.ndim - 1]); if (shape.ndim > 0) {
for (size_t i = 0; i < shape.ndim - 1; ++i) { flattened_shape_vec.push_back(product(shape, 0, shape.ndim - 1));
columnwise_shape_vec.emplace_back(shape.data[i]); flattened_shape_vec.push_back(shape.data[shape.ndim - 1]);
}
} else { } else {
// Same shape for MX and NVFP4 flattened_shape_vec.resize(2, 1);
for (size_t i = 0; i < shape.ndim; ++i) {
columnwise_shape_vec.emplace_back(shape.data[i]);
} }
flattened_shape = convertShape(flattened_shape_vec);
} }
if (columnwise) { // Allocate and initialize data
columnwise_shape = nvte_make_shape(columnwise_shape_vec.data(), columnwise_shape_vec.size()); void *dptr_rowwise = nullptr, *dptr_columnwise = nullptr;
} const size_t total_size = bytes(shape, type);
tensor_ = TensorWrapper(scaling_mode);
if (total_size != 0) { if (total_size != 0) {
if (rowwise) { if (rowwise) {
cudaMalloc((void**)&dptr_rowwise, total_size); // NOLINT(*) cudaMalloc((void**)&dptr_rowwise, total_size); // NOLINT(*)
...@@ -339,11 +320,51 @@ Tensor::Tensor(const std::string& name, ...@@ -339,11 +320,51 @@ Tensor::Tensor(const std::string& name,
} }
} }
// Set tensor row-wise data
if (rowwise) {
const DType rowwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type; const DType rowwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type;
const DType colwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type;
tensor_.set_rowwise_data(dptr_rowwise, rowwise_type, shape); tensor_.set_rowwise_data(dptr_rowwise, rowwise_type, shape);
}
// Set tensor column-wise data
if (columnwise) {
// Determine shape of column-wise data
std::vector<size_t> columnwise_shape_vec;
switch (scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING:
case NVTE_BLOCK_SCALING_1D:
case NVTE_BLOCK_SCALING_2D: {
// Column-wise data shape is transposed
if (shape.ndim > 0) {
columnwise_shape_vec.emplace_back(shape.data[shape.ndim - 1]);
for (size_t i = 0; i < shape.ndim - 1; ++i) {
columnwise_shape_vec.emplace_back(shape.data[i]);
}
}
break;
}
case NVTE_MXFP8_1D_SCALING:
case NVTE_NVFP4_1D_SCALING: {
// Column-wise data matches shape
for (size_t i = 0; i < shape.ndim; ++i) {
columnwise_shape_vec.emplace_back(shape.data[i]);
}
break;
}
default:
NVTE_ERROR("Unrecognized scaling mode (", (size_t)scaling_mode, ").");
}
const auto columnwise_shape = nvte_make_shape(columnwise_shape_vec.data(),
columnwise_shape_vec.size());
// Set column-wise data buffer
const DType colwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type;
tensor_.set_columnwise_data(dptr_columnwise, colwise_type, columnwise_shape); tensor_.set_columnwise_data(dptr_columnwise, colwise_type, columnwise_shape);
}
// Configure scales, amaxes, and other tensor buffers
float *amax = nullptr, *scale = nullptr;
float *rowwise_scale_inv = nullptr, *columnwise_scale_inv = nullptr;
if (isFp8Type(type) || isFp4Type(type)) { if (isFp8Type(type) || isFp4Type(type)) {
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
cudaMalloc((void**)&amax, sizeof(float)); // NOLINT(*) cudaMalloc((void**)&amax, sizeof(float)); // NOLINT(*)
...@@ -375,7 +396,7 @@ Tensor::Tensor(const std::string& name, ...@@ -375,7 +396,7 @@ Tensor::Tensor(const std::string& name,
scale_cpu_data_ = std::make_shared<float>(0); scale_cpu_data_ = std::make_shared<float>(0);
tensor_.set_scale(scale, DType::kFloat32, std::vector<size_t>{1}); tensor_.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
} }
auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(normalized_shape, tensor_.scaling_mode()); auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(flattened_shape, tensor_.scaling_mode());
auto rowwise_scale_size = rowwise_scale_meta.bytes(); auto rowwise_scale_size = rowwise_scale_meta.bytes();
auto columnwise_scale_size = colwise_scale_meta.bytes(); auto columnwise_scale_size = colwise_scale_meta.bytes();
auto scale_shape = rowwise_scale_meta.shape; auto scale_shape = rowwise_scale_meta.shape;
......
...@@ -74,38 +74,49 @@ inline size_t product(const std::vector<size_t> &shape) { ...@@ -74,38 +74,49 @@ inline size_t product(const std::vector<size_t> &shape) {
return ret; return ret;
} }
size_t get_buffer_size_bytes(const size_t N, const DType buffer_dtype);
size_t get_buffer_size_bytes(const size_t dim_first, const size_t dim_last,
const DType buffer_dtype);
struct SimpleTensor { struct SimpleTensor {
void *dptr; void *dptr;
std::vector<size_t> shape; std::vector<size_t> shape;
DType dtype; DType dtype;
SimpleTensor(void *dptr, const std::vector<size_t> &shape, DType dtype) SimpleTensor(void *dptr, std::vector<size_t> shape, DType dtype)
: dptr(dptr), shape(shape), dtype(dtype) {} : dptr{dptr}, shape{std::move(shape)}, dtype{dtype} {}
SimpleTensor(const NVTEBasicTensor &tensor) // NOLINT SimpleTensor(const NVTEBasicTensor &tensor) // NOLINT
: dptr(tensor.data_ptr), : dptr(tensor.data_ptr),
shape(tensor.shape.data, tensor.shape.data + tensor.shape.ndim), shape(tensor.shape.data, tensor.shape.data + tensor.shape.ndim),
dtype(static_cast<DType>(tensor.dtype)) {} dtype(static_cast<DType>(tensor.dtype)) {}
SimpleTensor() : SimpleTensor(nullptr, {}, DType::kFloat32) {} SimpleTensor() : SimpleTensor(nullptr, std::vector<size_t>{0}, DType::kFloat32) {}
operator NVTEBasicTensor() const { operator NVTEBasicTensor() const {
return {dptr, static_cast<NVTEDType>(dtype), return {dptr, static_cast<NVTEDType>(dtype),
nvte_make_shape(this->shape.data(), this->shape.size())}; nvte_make_shape(this->shape.data(), this->shape.size())};
} }
size_t numel() const { /*! Number of tensor elements. */
size_t acc = 1; size_t numel() const { return product(shape); }
for (const auto &dim : shape) {
acc *= dim; /*! Whether the tensor is initialized.
} *
return acc; * Tensors with non-trivial shapes are considered initialized. This
} * means that there is no guarantee that the data pointer can be
bool has_data() const noexcept { return dptr != nullptr && numel() > 0; } * safely accessed.
*/
bool has_data() const { return !(dptr == nullptr && shape.size() == 1 && shape[0] == 0); }
/*! Buffer size in bytes. */
size_t buffer_size_bytes() const { return get_buffer_size_bytes(numel(), dtype); }
/*! Reset to uninitialized tensor. */
void clear() { void clear() {
dptr = nullptr; dptr = nullptr;
shape.resize(0); shape.resize(1);
shape[0] = 0;
dtype = DType::kFloat32; dtype = DType::kFloat32;
} }
}; };
...@@ -123,17 +134,9 @@ struct Tensor { ...@@ -123,17 +134,9 @@ struct Tensor {
NVTEScalingMode scaling_mode; NVTEScalingMode scaling_mode;
NVTETensor nvte_tensor; NVTETensor nvte_tensor;
Tensor() Tensor() : scaling_mode{NVTE_DELAYED_TENSOR_SCALING}, nvte_tensor{0} {}
: data(),
columnwise_data(),
amax(nullptr, {1}, DType::kFloat32),
columnwise_amax(nullptr, {1}, DType::kFloat32),
scale(nullptr, {1}, DType::kFloat32),
scale_inv(nullptr, {1}, DType::kFloat32),
columnwise_scale_inv(nullptr, {1}, DType::kFloat32),
scaling_mode(NVTE_DELAYED_TENSOR_SCALING),
nvte_tensor(0) {}
/*! Reset tensor data. */
void clear() { void clear() {
data.clear(); data.clear();
columnwise_data.clear(); columnwise_data.clear();
...@@ -147,65 +150,62 @@ struct Tensor { ...@@ -147,65 +150,62 @@ struct Tensor {
explicit operator NVTETensor() const noexcept { return nvte_tensor; } explicit operator NVTETensor() const noexcept { return nvte_tensor; }
/*! Number of tensor elements. */
size_t numel() const { size_t numel() const {
size_t acc = 1; if (!has_data() && has_columnwise_data()) {
for (const auto dim : shape()) { return product(columnwise_data.shape);
acc *= dim;
} }
return acc; return product(data.shape);
} }
// TODO(Tim): Change this to use data.has_data() /*! Whether the tensor data buffer is not uninitialized.
bool has_data() const noexcept { return data.dptr != nullptr; } *
* Buffers with non-trivial shapes are considered initialized. This
* means that there is no guarantee that the data pointer can be
* safely accessed.
*/
bool has_data() const { return data.has_data(); }
// Check for size (not just pointer) for 0-dim or no token cases. /*! Whether the tensor column-wise data buffer is not uninitialized.
// TODO(Tim): Change this to use columnwise_data.has_data() *
bool has_columnwise_data() const noexcept { * Buffers with non-trivial shapes are considered initialized. This
return columnwise_data.dptr != nullptr || columnwise_data.shape.size() != 0; * means that there is no guarantee that the data pointer can be
} * safely accessed.
*/
bool has_columnwise_data() const { return columnwise_data.has_data(); }
/*! Datatype of tensor elements. */
DType dtype() const { DType dtype() const {
if (has_data()) return data.dtype; if (!has_data() && has_columnwise_data()) {
if (has_columnwise_data()) return columnwise_data.dtype; return columnwise_data.dtype;
// Fallback, used e.g. in workspace }
return data.dtype; return data.dtype;
} }
/*! Number of tensor dimensions. */
size_t dim() const { size_t dim() const {
if (!has_data() && has_columnwise_data()) { if (!has_data() && has_columnwise_data()) {
return columnwise_data.shape.size(); return columnwise_data.shape.size();
} else {
return data.shape.size();
} }
return data.shape.size();
} }
std::vector<size_t> shape() const { /*! Tensor dimensions.
/* Note: We sometimes experience spurious compiler errors *
* (-Wstringop-overflow) from this function. It appears that GCC * This is the logical tensor shape. The underlying data may have a
* has some bugs with std::vector (see * different shape, e.g. the column-wise data for some tensor
* https://gcc.gnu.org/bugzilla/show_bug.cgi?id=109569). * formats are transposed.
*/ */
std::vector<size_t> shape() const {
// Each tensor format interprets its data differently
switch (scaling_mode) { switch (scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: case NVTE_DELAYED_TENSOR_SCALING:
case NVTE_BLOCK_SCALING_1D:
case NVTE_BLOCK_SCALING_2D:
case NVTE_NVFP4_1D_SCALING: { case NVTE_NVFP4_1D_SCALING: {
// Choose data buffer based on whether it is initialized // Row-wise data shape matches tensor logical shape,
// Note: Uninitialized buffers currently have shape=[]. // column-wise data shape is transpose of logical shape
// However, this is logically incorrect. 0-D tensors have 1 if (!has_data() && has_columnwise_data()) {
// entry, and uninitialized tensors should have shape=[0].
bool use_columnwise_shape = false;
if (data.dptr != nullptr) {
use_columnwise_shape = false;
} else if (columnwise_data.dptr != nullptr) {
use_columnwise_shape = true;
} else if (data.shape.size() != 0) {
use_columnwise_shape = false;
} else if (columnwise_data.shape.size() != 0) {
use_columnwise_shape = true;
}
// Infer shape based on data
if (use_columnwise_shape) {
// Column-wise data is transposed
std::vector<size_t> ret; std::vector<size_t> ret;
if (!columnwise_data.shape.empty()) { if (!columnwise_data.shape.empty()) {
ret.reserve(columnwise_data.shape.size()); ret.reserve(columnwise_data.shape.size());
...@@ -218,38 +218,16 @@ struct Tensor { ...@@ -218,38 +218,16 @@ struct Tensor {
} }
return data.shape; return data.shape;
} }
case NVTE_MXFP8_1D_SCALING: case NVTE_MXFP8_1D_SCALING: {
// Row-wise and column-wise data shapes both match tensor
// logical shape
if (!has_data() && has_columnwise_data()) { if (!has_data() && has_columnwise_data()) {
return columnwise_data.shape; return columnwise_data.shape;
} else {
return data.shape;
}
break;
case NVTE_BLOCK_SCALING_1D:
case NVTE_BLOCK_SCALING_2D: {
if (!has_data() && has_columnwise_data()) {
std::vector<size_t> shape;
size_t ndim = columnwise_data.shape.size();
shape.reserve(ndim);
for (size_t i = 0; i + 1 < ndim; ++i) {
shape.push_back(columnwise_data.shape[i + 1]);
}
if (ndim > 0) {
shape.push_back(columnwise_data.shape[0]);
} }
return shape;
} else {
// NOTE: We may have removed the data pointer from
// data by setting usage. In that case, we return
// the non-null shape. It is our best guess at the most
// recent shape.
return data.shape; return data.shape;
} }
break;
}
default: default:
NVTE_ERROR("Cannot parse tensor shape with scaling mode \"", to_string(scaling_mode), "\""); NVTE_ERROR("Cannot parse tensor shape with scaling mode \"", to_string(scaling_mode), "\"");
return {};
} }
} }
...@@ -347,10 +325,10 @@ struct GroupedTensor { ...@@ -347,10 +325,10 @@ struct GroupedTensor {
columnwise_amax(), columnwise_amax(),
scale(), scale(),
num_tensors(num_tensors), num_tensors(num_tensors),
first_dims(nullptr, {}, DType::kInt64), first_dims(nullptr, std::vector<size_t>{0}, DType::kInt64),
last_dims(nullptr, {}, DType::kInt64), last_dims(nullptr, std::vector<size_t>{0}, DType::kInt64),
tensor_offsets(nullptr, {}, DType::kInt64), tensor_offsets(nullptr, std::vector<size_t>{0}, DType::kInt64),
logical_shape(nvte_make_shape(nullptr, 0)), logical_shape(nvte_make_shape(nullptr, 1)),
scaling_mode(scaling_mode), scaling_mode(scaling_mode),
nvte_tensor(0) {} nvte_tensor(0) {}
...@@ -383,9 +361,9 @@ struct GroupedTensor { ...@@ -383,9 +361,9 @@ struct GroupedTensor {
} }
DType dtype() const { DType dtype() const {
if (has_data()) return data.dtype; if (!has_data() && has_columnwise_data()) {
if (has_columnwise_data()) return columnwise_data.dtype; return columnwise_data.dtype;
// Fallback, used e.g. in workspace or when allow_empty=true }
return data.dtype; return data.dtype;
} }
...@@ -400,7 +378,7 @@ struct GroupedTensor { ...@@ -400,7 +378,7 @@ struct GroupedTensor {
first_dims.clear(); first_dims.clear();
last_dims.clear(); last_dims.clear();
tensor_offsets.clear(); tensor_offsets.clear();
logical_shape = nvte_make_shape(nullptr, 0); logical_shape = nvte_make_shape(nullptr, 1);
num_tensors = 0; num_tensors = 0;
scaling_mode = NVTE_DELAYED_TENSOR_SCALING; scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
nvte_tensor = 0; nvte_tensor = 0;
...@@ -869,10 +847,6 @@ inline bool is_aligned_tensor_data(const Tensor &t, size_t alignment) { ...@@ -869,10 +847,6 @@ inline bool is_aligned_tensor_data(const Tensor &t, size_t alignment) {
size_t typeToSize(const DType type); size_t typeToSize(const DType type);
size_t typeToNumBits(const DType type); size_t typeToNumBits(const DType type);
size_t get_buffer_size_bytes(const size_t N, const DType buffer_dtype);
size_t get_buffer_size_bytes(const size_t dim_first, const size_t dim_last,
const DType buffer_dtype);
void CheckNoopTensor(const Tensor &t, const std::string &name); void CheckNoopTensor(const Tensor &t, const std::string &name);
void CheckInputTensor(const Tensor &t, const std::string &name); void CheckInputTensor(const Tensor &t, const std::string &name);
void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty = false); void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty = false);
......
...@@ -142,7 +142,8 @@ void *nvte_tensor_columnwise_data(const NVTETensor tensor); ...@@ -142,7 +142,8 @@ void *nvte_tensor_columnwise_data(const NVTETensor tensor);
/*! \brief Construct a shape from an array of dimension sizes. /*! \brief Construct a shape from an array of dimension sizes.
* *
* \param[data] Pointer to start of shape array. * \param[data] Pointer to start of shape array. If NULL, the shape
* will be filled with zeros.
 * \param[ndim] Number of dimensions (must be <= 14) * \param[ndim] Number of dimensions (must be <= 14)
* *
* \return A shape. The shape will own its own copy of the data. * \return A shape. The shape will own its own copy of the data.
...@@ -575,15 +576,22 @@ class TensorWrapper { ...@@ -575,15 +576,22 @@ class TensorWrapper {
*/ */
TensorWrapper(void *dptr, const NVTEShape &shape, const DType dtype, float *amax_dptr = nullptr, TensorWrapper(void *dptr, const NVTEShape &shape, const DType dtype, float *amax_dptr = nullptr,
float *scale_dptr = nullptr, float *scale_inv_dptr = nullptr, float *scale_dptr = nullptr, float *scale_inv_dptr = nullptr,
const NVTEShape scale_inv_shape = defaultShape, NVTEShape scale_inv_shape = defaultShape,
const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) { const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) {
tensor_ = nvte_create_tensor(scaling_mode); tensor_ = nvte_create_tensor(scaling_mode);
NVTEBasicTensor data = {dptr, static_cast<NVTEDType>(dtype), shape}; NVTEBasicTensor data = {dptr, static_cast<NVTEDType>(dtype), shape};
nvte_set_tensor_param(&tensor_, kNVTERowwiseData, &data); nvte_set_tensor_param(&tensor_, kNVTERowwiseData, &data);
NVTEBasicTensor amax = {amax_dptr, kNVTEFloat32, defaultShape}; NVTEBasicTensor amax = {amax_dptr, kNVTEFloat32,
amax_dptr != nullptr ? defaultShape : emptyShape};
nvte_set_tensor_param(&tensor_, kNVTEAmax, &amax); nvte_set_tensor_param(&tensor_, kNVTEAmax, &amax);
NVTEBasicTensor scale = {scale_dptr, kNVTEFloat32, defaultShape}; NVTEBasicTensor scale = {scale_dptr, kNVTEFloat32,
scale_dptr != nullptr ? defaultShape : emptyShape};
nvte_set_tensor_param(&tensor_, kNVTEScale, &scale); nvte_set_tensor_param(&tensor_, kNVTEScale, &scale);
if (scale_inv_dptr == nullptr && scale_inv_shape.ndim == defaultShape.ndim &&
scale_inv_shape.ndim == 1 && scale_inv_shape.data[0] == defaultShape.data[0]) {
// Scale-inv pointer has not been provided and shape matches default
scale_inv_shape = emptyShape;
}
NVTEBasicTensor scale_inv = {scale_inv_dptr, kNVTEFloat32, scale_inv_shape}; NVTEBasicTensor scale_inv = {scale_inv_dptr, kNVTEFloat32, scale_inv_shape};
nvte_set_tensor_param(&tensor_, kNVTERowwiseScaleInv, &scale_inv); nvte_set_tensor_param(&tensor_, kNVTERowwiseScaleInv, &scale_inv);
} }
...@@ -734,7 +742,7 @@ class TensorWrapper { ...@@ -734,7 +742,7 @@ class TensorWrapper {
*/ */
const NVTEShape shape() const noexcept { const NVTEShape shape() const noexcept {
if (tensor_ == nullptr) { if (tensor_ == nullptr) {
return nvte_make_shape(nullptr, 0); return emptyShape;
} }
return nvte_tensor_shape(tensor_); return nvte_tensor_shape(tensor_);
} }
...@@ -745,7 +753,7 @@ class TensorWrapper { ...@@ -745,7 +753,7 @@ class TensorWrapper {
*/ */
const NVTEShape columnwise_shape() const noexcept { const NVTEShape columnwise_shape() const noexcept {
if (tensor_ == nullptr) { if (tensor_ == nullptr) {
return nvte_make_shape(nullptr, 0); return emptyShape;
} }
return nvte_tensor_columnwise_shape(tensor_); return nvte_tensor_columnwise_shape(tensor_);
} }
...@@ -869,7 +877,7 @@ class TensorWrapper { ...@@ -869,7 +877,7 @@ class TensorWrapper {
*/ */
const NVTEShape scale_inv_shape() const noexcept { const NVTEShape scale_inv_shape() const noexcept {
if (tensor_ == nullptr) { if (tensor_ == nullptr) {
return nvte_make_shape(nullptr, 0); return emptyShape;
} }
return nvte_tensor_scale_inv_shape(tensor_); return nvte_tensor_scale_inv_shape(tensor_);
} }
...@@ -888,6 +896,7 @@ class TensorWrapper { ...@@ -888,6 +896,7 @@ class TensorWrapper {
static constexpr size_t defaultData = 1; static constexpr size_t defaultData = 1;
static constexpr NVTEShape defaultShape = { static constexpr NVTEShape defaultShape = {
{defaultData, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1}; {defaultData, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1};
static constexpr NVTEShape emptyShape = {{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1};
private: private:
NVTEShape convertShape(const NVTEShape &s) { return s; } NVTEShape convertShape(const NVTEShape &s) { return s; }
......
...@@ -127,7 +127,13 @@ void TeNormalizationPlan<KernelParamsType>::_build() { ...@@ -127,7 +127,13 @@ void TeNormalizationPlan<KernelParamsType>::_build() {
template <typename KernelParamsType> template <typename KernelParamsType>
std::vector<size_t> TeNormalizationPlan<KernelParamsType>::getWorkspaceShape() const { std::vector<size_t> TeNormalizationPlan<KernelParamsType>::getWorkspaceShape() const {
return {_launch_params.getTotalWorkspaceBytes(_is_layernorm)}; size_t workspace_size = _launch_params.getTotalWorkspaceBytes(_is_layernorm);
if (workspace_size == 0) {
// Workspace size must not be zero since that corresponds to a
// workspace size query
workspace_size = 1;
}
return {workspace_size};
} }
template <typename KernelParamsType> template <typename KernelParamsType>
...@@ -405,7 +411,13 @@ void CudnnNormalizationPlan::_build() { ...@@ -405,7 +411,13 @@ void CudnnNormalizationPlan::_build() {
} }
std::vector<size_t> CudnnNormalizationPlan::getWorkspaceShape() const { std::vector<size_t> CudnnNormalizationPlan::getWorkspaceShape() const {
return {static_cast<size_t>(_graph.get_workspace_size())}; size_t workspace_size = _graph.get_workspace_size();
if (workspace_size == 0) {
// Workspace size must not be zero since that corresponds to a
// workspace size query
workspace_size = 1;
}
return {workspace_size};
} }
void CudnnNormalizationPlan::execute(Tensor* z, void* x_dptr, void* gamma_dptr, void* beta_dptr, void CudnnNormalizationPlan::execute(Tensor* z, void* x_dptr, void* gamma_dptr, void* beta_dptr,
......
...@@ -51,7 +51,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size ...@@ -51,7 +51,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
"RSigma must be 1D tensor with shape (x.shape[0],)."); "RSigma must be 1D tensor with shape (x.shape[0],).");
NVTE_CHECK(rsigma->data.dtype == DType::kFloat32, "RSigma must be a float32 tensor."); NVTE_CHECK(rsigma->data.dtype == DType::kFloat32, "RSigma must be a float32 tensor.");
if (!workspace->data.shape.empty()) { if (workspace->data.numel() != 0) {
CheckInputTensor(x, "x"); CheckInputTensor(x, "x");
CheckInputTensor(gamma, "gamma"); CheckInputTensor(gamma, "gamma");
CheckInputTensor(beta, "beta"); CheckInputTensor(beta, "beta");
...@@ -94,7 +94,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size ...@@ -94,7 +94,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training, multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training,
gamma_in_weight_dtype); gamma_in_weight_dtype);
if (workspace->data.shape.empty()) { if (workspace->data.numel() == 0) {
workspace->data.shape = plan->getWorkspaceShape(); workspace->data.shape = plan->getWorkspaceShape();
workspace->data.dtype = DType::kByte; workspace->data.dtype = DType::kByte;
return; return;
...@@ -146,7 +146,7 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te ...@@ -146,7 +146,7 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te
NVTE_CHECK(dbeta->data.shape == gamma.data.shape); NVTE_CHECK(dbeta->data.shape == gamma.data.shape);
NVTE_CHECK(dbeta->data.dtype == gamma.data.dtype); NVTE_CHECK(dbeta->data.dtype == gamma.data.dtype);
if (!workspace->data.shape.empty()) { if (workspace->data.numel() != 0) {
CheckInputTensor(dz, "dz"); CheckInputTensor(dz, "dz");
CheckInputTensor(x, "x"); CheckInputTensor(x, "x");
CheckInputTensor(mu, "mu"); CheckInputTensor(mu, "mu");
...@@ -179,7 +179,7 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te ...@@ -179,7 +179,7 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te
multiprocessorCount, zero_centered_gamma, is_aligned, NVTE_DELAYED_TENSOR_SCALING, true, multiprocessorCount, zero_centered_gamma, is_aligned, NVTE_DELAYED_TENSOR_SCALING, true,
gamma_in_weight_dtype); gamma_in_weight_dtype);
if (workspace->data.shape.empty()) { if (workspace->data.numel() == 0) {
workspace->data.shape = plan->getWorkspaceShape(); workspace->data.shape = plan->getWorkspaceShape();
workspace->data.dtype = DType::kByte; workspace->data.dtype = DType::kByte;
return; return;
......
...@@ -39,7 +39,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens ...@@ -39,7 +39,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
"RSigma must be 1D tensor with shape (x.shape[0],)."); "RSigma must be 1D tensor with shape (x.shape[0],).");
NVTE_CHECK(rsigma->data.dtype == DType::kFloat32, "RSigma must be a float32 tensor."); NVTE_CHECK(rsigma->data.dtype == DType::kFloat32, "RSigma must be a float32 tensor.");
if (!workspace->data.shape.empty()) { if (workspace->data.numel() != 0) {
CheckInputTensor(x, "x"); CheckInputTensor(x, "x");
CheckInputTensor(gamma, "gamma"); CheckInputTensor(gamma, "gamma");
...@@ -79,7 +79,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens ...@@ -79,7 +79,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training, multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training,
gamma_in_weight_dtype); gamma_in_weight_dtype);
if (workspace->data.shape.empty()) { if (workspace->data.numel() == 0) {
workspace->data.shape = plan->getWorkspaceShape(); workspace->data.shape = plan->getWorkspaceShape();
workspace->data.dtype = DType::kByte; workspace->data.dtype = DType::kByte;
return; return;
...@@ -125,7 +125,7 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const ...@@ -125,7 +125,7 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const
NVTE_CHECK(dgamma->data.shape == gamma.data.shape); NVTE_CHECK(dgamma->data.shape == gamma.data.shape);
NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype); NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype);
if (!workspace->data.shape.empty()) { if (workspace->data.numel() != 0) {
CheckInputTensor(dz, "dz"); CheckInputTensor(dz, "dz");
CheckInputTensor(x, "x"); CheckInputTensor(x, "x");
CheckInputTensor(rsigma, "rsigma"); CheckInputTensor(rsigma, "rsigma");
...@@ -156,7 +156,7 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const ...@@ -156,7 +156,7 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const
multiprocessorCount, zero_centered_gamma, is_aligned, NVTE_DELAYED_TENSOR_SCALING, true, multiprocessorCount, zero_centered_gamma, is_aligned, NVTE_DELAYED_TENSOR_SCALING, true,
gamma_in_weight_dtype); gamma_in_weight_dtype);
if (workspace->data.shape.empty()) { if (workspace->data.numel() == 0) {
workspace->data.shape = plan->getWorkspaceShape(); workspace->data.shape = plan->getWorkspaceShape();
workspace->data.dtype = DType::kByte; workspace->data.dtype = DType::kByte;
return; return;
...@@ -191,7 +191,7 @@ void rmsnorm_bwd_add(const Tensor &dz, const Tensor &x, const Tensor &add, const ...@@ -191,7 +191,7 @@ void rmsnorm_bwd_add(const Tensor &dz, const Tensor &x, const Tensor &add, const
NVTE_CHECK(dgamma->data.shape == gamma.data.shape); NVTE_CHECK(dgamma->data.shape == gamma.data.shape);
NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype); NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype);
if (!workspace->data.shape.empty()) { if (workspace->data.numel() != 0) {
CheckInputTensor(dz, "dz"); CheckInputTensor(dz, "dz");
CheckInputTensor(x, "x"); CheckInputTensor(x, "x");
CheckInputTensor(add, "add"); CheckInputTensor(add, "add");
...@@ -222,7 +222,7 @@ void rmsnorm_bwd_add(const Tensor &dz, const Tensor &x, const Tensor &add, const ...@@ -222,7 +222,7 @@ void rmsnorm_bwd_add(const Tensor &dz, const Tensor &x, const Tensor &add, const
multiprocessorCount, zero_centered_gamma, is_aligned, NVTE_DELAYED_TENSOR_SCALING, true, multiprocessorCount, zero_centered_gamma, is_aligned, NVTE_DELAYED_TENSOR_SCALING, true,
gamma_in_weight_dtype); gamma_in_weight_dtype);
if (workspace->data.shape.empty()) { if (workspace->data.numel() == 0) {
workspace->data.shape = plan->getWorkspaceShape(); workspace->data.shape = plan->getWorkspaceShape();
workspace->data.dtype = DType::kByte; workspace->data.dtype = DType::kByte;
return; return;
......
...@@ -332,68 +332,118 @@ __global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_ ...@@ -332,68 +332,118 @@ __global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_
} // namespace } // namespace
void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t stream) { void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t stream) {
NVTE_CHECK( // Check scaling mode
input->scaling_mode == NVTE_MXFP8_1D_SCALING || input->scaling_mode == NVTE_NVFP4_1D_SCALING, const auto& scaling_mode = input->scaling_mode;
NVTE_CHECK(scaling_mode == NVTE_MXFP8_1D_SCALING || scaling_mode == NVTE_NVFP4_1D_SCALING,
"Input tensor has invalid scaling mode (", to_string(input->scaling_mode), ")."); "Input tensor has invalid scaling mode (", to_string(input->scaling_mode), ").");
NVTE_CHECK(is_fp8_dtype(input->dtype()) || is_fp4_dtype(input->dtype()),
"Input tensor has invalid dtype (", to_string(input->dtype()), ").");
// Do nothing if tensor is empty
if (input->data.numel() == 0) {
return;
}
// Check tensors
CheckInputTensor(*input, "scaling_factor_input"); CheckInputTensor(*input, "scaling_factor_input");
CheckInputTensor(*output, "scaling_factor_output"); CheckInputTensor(*output, "scaling_factor_output");
switch (scaling_mode) {
case NVTE_MXFP8_1D_SCALING:
NVTE_CHECK(is_fp8_dtype(input->dtype()), "Input tensor has invalid dtype (expected FP8, got ",
to_string(input->dtype()), ").");
break;
case NVTE_NVFP4_1D_SCALING:
NVTE_CHECK(is_fp4_dtype(input->dtype()), "Input tensor has invalid dtype (expected FP4, got ",
to_string(input->dtype()), ").");
break;
default:
NVTE_ERROR("Invalid scaling mode");
}
auto& scaling_mode = input->scaling_mode; // Check if scaling factors are non-trivial
NVTE_CHECK(scaling_mode == NVTE_MXFP8_1D_SCALING || scaling_mode == NVTE_NVFP4_1D_SCALING, const bool has_rowwise_scale_inv = input->scale_inv.has_data();
"Unsupported scaling mode for swizzling."); const bool has_columnwise_scale_inv = input->columnwise_scale_inv.has_data();
NVTE_CHECK(!has_rowwise_scale_inv || !has_columnwise_scale_inv,
bool nvfp4 = scaling_mode == NVTE_NVFP4_1D_SCALING; "Input tensor has both row-wise and column-wise scaling factors");
if (!has_rowwise_scale_inv && !has_columnwise_scale_inv) {
return;
}
// 1D block scaling, row-wise or colum-wise // Deduce tensor dims
int m, k; int m{0}, k{0};
if (input->has_data()) { switch (scaling_mode) {
case NVTE_MXFP8_1D_SCALING: {
if (has_rowwise_scale_inv) {
NVTE_CHECK(input->scale_inv.shape.size() == 2,
"Expected 2D scaling factors, got shape=", input->scale_inv.shape, ".");
m = input->scale_inv.shape[0]; m = input->scale_inv.shape[0];
k = input->scale_inv.shape[1]; k = input->scale_inv.shape[1];
} else { } else if (has_columnwise_scale_inv) {
if (nvfp4) { NVTE_CHECK(input->columnwise_scale_inv.shape.size() == 2,
m = input->columnwise_scale_inv.shape[0]; "Expected 2D scaling factors, got shape=", input->columnwise_scale_inv.shape,
k = input->columnwise_scale_inv.shape[1]; ".");
} else {
m = input->columnwise_scale_inv.shape[1]; m = input->columnwise_scale_inv.shape[1];
k = input->columnwise_scale_inv.shape[0]; k = input->columnwise_scale_inv.shape[0];
} }
break;
}
case NVTE_NVFP4_1D_SCALING: {
if (has_rowwise_scale_inv) {
NVTE_CHECK(input->scale_inv.shape.size() == 2,
"Expected 2D scaling factors, got shape=", input->scale_inv.shape, ".");
m = input->scale_inv.shape[0];
k = input->scale_inv.shape[1];
} else if (has_columnwise_scale_inv) {
NVTE_CHECK(input->columnwise_scale_inv.shape.size() == 2,
"Expected 2D scaling factors, got shape=", input->columnwise_scale_inv.shape,
".");
m = input->columnwise_scale_inv.shape[0];
k = input->columnwise_scale_inv.shape[1];
}
break;
}
default:
NVTE_ERROR("Invalid scaling mode");
} }
// Check dims
constexpr int SF_TILE_DIM_M = 128; constexpr int SF_TILE_DIM_M = 128;
constexpr int SF_TILE_DIM_K = 4; constexpr int SF_TILE_DIM_K = 4;
NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!"); NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!");
NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!"); NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!");
NVTE_CHECK(k > 0, "Input scale inverse should be 2D!");
if (output->has_data()) { // Check that output tensor matches input tensor
NVTE_CHECK(m * k == std::accumulate(output->scale_inv.shape.begin(), if (has_rowwise_scale_inv) {
output->scale_inv.shape.end(), 1, std::multiplies<int>()), NVTE_CHECK(output->scale_inv.has_data(),
"Input.scale_inv size is not equal to Output.scale_inv size!"); "Output tensor does not have row-wise scaling factors.");
} NVTE_CHECK(m * k == output->scale_inv.numel(), "Expected output tensor to have ", m * k,
if (output->has_columnwise_data()) { " row-wise scaling factors, but got shape=", output->scale_inv.shape, ".");
NVTE_CHECK(m * k == std::accumulate(output->columnwise_scale_inv.shape.begin(), }
output->columnwise_scale_inv.shape.end(), 1, if (has_columnwise_scale_inv) {
std::multiplies<int>()), NVTE_CHECK(output->columnwise_scale_inv.has_data(),
"Input.columnwise_scale_inv size is not equal to " "Output tensor does not have column-wise scaling factors.");
"Output.columnwise_scale_inv size!"); NVTE_CHECK(
m * k == output->columnwise_scale_inv.numel(), "Expected output tensor to have ", m * k,
" column-wise scaling factors, but got shape=", output->columnwise_scale_inv.shape, ".");
} }
int num_tiles_m = m / SF_TILE_DIM_M; // Choose swizzle implementation
int num_tiles_k = k / SF_TILE_DIM_K; bool rowwise_swizzle{false}, columnwise_swizzle{false};
switch (scaling_mode) {
case NVTE_MXFP8_1D_SCALING: {
rowwise_swizzle = has_rowwise_scale_inv;
columnwise_swizzle = has_columnwise_scale_inv;
break;
}
case NVTE_NVFP4_1D_SCALING: {
// NVFP4 column-wise data is transposed, so row-wise and
// column-wise scales have same swizzling format
rowwise_swizzle = true;
columnwise_swizzle = false;
break;
}
default:
NVTE_ERROR("Invalid scaling mode");
}
// For NVFP4, the scale inverse for tranposed data needs rowwise swizzle. const dim3 block_size(TB_DIM, TB_DIM);
const bool rowwise_swizzle = input->has_data() || nvfp4; const int num_tiles_m = m / SF_TILE_DIM_M;
const bool columnwise_swizzle = input->has_columnwise_data() && !nvfp4; const int num_tiles_k = k / SF_TILE_DIM_K;
dim3 block_size(TB_DIM, TB_DIM); // Perform row-wise swizzle
if (rowwise_swizzle) { if (rowwise_swizzle) {
int vec_load_size = (num_tiles_k - 1) % 4 + 1; int vec_load_size = (num_tiles_k - 1) % 4 + 1;
/* there is no int3 and misaligned if using int4/int2 */ /* there is no int3 and misaligned if using int4/int2 */
...@@ -402,21 +452,33 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s ...@@ -402,21 +452,33 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m); dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m);
int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
int original_M, original_K; int original_M{0}, original_K{0};
void *input_scale_inv_ptr, *output_scale_inv_ptr; void *input_scale_inv_ptr{nullptr}, *output_scale_inv_ptr{nullptr};
switch (scaling_mode) {
if (!nvfp4 || input->has_data()) { case NVTE_MXFP8_1D_SCALING: {
int block_scale_size = nvfp4 ? NVFP4_BLOCK_SIZE : MXFP8_BLOCK_SIZE;
original_M = input->flat_first_dim(); original_M = input->flat_first_dim();
original_K = input->flat_last_dim() / block_scale_size; original_K = input->flat_last_dim() / MXFP8_BLOCK_SIZE;
input_scale_inv_ptr = input->scale_inv.dptr; input_scale_inv_ptr = input->scale_inv.dptr;
output_scale_inv_ptr = output->scale_inv.dptr; output_scale_inv_ptr = output->scale_inv.dptr;
} else { break;
}
case NVTE_NVFP4_1D_SCALING: {
if (has_rowwise_scale_inv) {
original_M = input->flat_first_dim();
original_K = input->flat_last_dim() / NVFP4_BLOCK_SIZE;
input_scale_inv_ptr = input->scale_inv.dptr;
output_scale_inv_ptr = output->scale_inv.dptr;
} else if (has_columnwise_scale_inv) {
original_M = input->flat_last_dim(); original_M = input->flat_last_dim();
original_K = input->flat_first_dim() / NVFP4_BLOCK_SIZE; original_K = input->flat_first_dim() / NVFP4_BLOCK_SIZE;
input_scale_inv_ptr = input->columnwise_scale_inv.dptr; input_scale_inv_ptr = input->columnwise_scale_inv.dptr;
output_scale_inv_ptr = output->columnwise_scale_inv.dptr; output_scale_inv_ptr = output->columnwise_scale_inv.dptr;
} }
break;
}
default:
NVTE_ERROR("Invalid scaling mode");
}
switch (vec_load_size) { switch (vec_load_size) {
case 4: case 4:
...@@ -447,7 +509,10 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s ...@@ -447,7 +509,10 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
NVTE_ERROR("Not valid vec_load_size."); NVTE_ERROR("Not valid vec_load_size.");
break; break;
} }
NVTE_CHECK_CUDA(cudaGetLastError());
} }
// Perform column-wise swizzle
if (columnwise_swizzle) { if (columnwise_swizzle) {
int vec_load_size = (num_tiles_m - 1) % 4 + 1; int vec_load_size = (num_tiles_m - 1) % 4 + 1;
if (vec_load_size == 3) vec_load_size = 1; /* no int3 and misaligned if using int4/int2 */ if (vec_load_size == 3) vec_load_size = 1; /* no int3 and misaligned if using int4/int2 */
...@@ -456,8 +521,6 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s ...@@ -456,8 +521,6 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
const int original_M = input->flat_last_dim(); const int original_M = input->flat_last_dim();
const int original_K = input->flat_first_dim() / MXFP8_BLOCK_SIZE; const int original_K = input->flat_first_dim() / MXFP8_BLOCK_SIZE;
// NVFP4 shouldn't end up here because it only needs rowwise swizzle
NVTE_CHECK(!nvfp4, "NVFP4 shouldn't end up here because it only needs rowwise swizzle");
switch (vec_load_size) { switch (vec_load_size) {
case 4: case 4:
...@@ -491,9 +554,8 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s ...@@ -491,9 +554,8 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
NVTE_ERROR("Not valid vec_load_size."); NVTE_ERROR("Not valid vec_load_size.");
break; break;
} }
}
NVTE_CHECK_CUDA(cudaGetLastError()); NVTE_CHECK_CUDA(cudaGetLastError());
}
} }
template <int SF_TILE_DIM_M, int SF_TILE_DIM_K> template <int SF_TILE_DIM_M, int SF_TILE_DIM_K>
...@@ -595,17 +657,18 @@ void multi_tensor_swizzle_scaling_factors(const std::vector<Tensor*>& input, ...@@ -595,17 +657,18 @@ void multi_tensor_swizzle_scaling_factors(const std::vector<Tensor*>& input,
(is_fp8 && is_mxfp8_scaling(scaling_mode)) || (is_fp4 && is_nvfp4_scaling(scaling_mode)), (is_fp8 && is_mxfp8_scaling(scaling_mode)) || (is_fp4 && is_nvfp4_scaling(scaling_mode)),
"Not implemented scaling mode " + to_string(scaling_mode) + "."); "Not implemented scaling mode " + to_string(scaling_mode) + ".");
// We don't allow empty tensors. They should be filtered out before calling this function. // We don't allow empty tensors. They should be filtered out before calling this function.
if (input[i]->data.numel() == 0) { NVTE_CHECK(input[i]->numel() != 0, "Tensor input[", i, "] is empty.");
NVTE_ERROR("Tensor input[" + std::to_string(i) + "] is empty.");
}
CheckInputTensor(*input[i], "scaling_factor_input[" + std::to_string(i) + "]"); CheckInputTensor(*input[i], "scaling_factor_input[" + std::to_string(i) + "]");
CheckInputTensor(*output[i], "scaling_factor_output[" + std::to_string(i) + "]"); CheckInputTensor(*output[i], "scaling_factor_output[" + std::to_string(i) + "]");
all_has_data &= input[i]->has_data(); all_has_data = all_has_data && input[i]->scale_inv.has_data();
all_has_columnwise_data &= input[i]->has_columnwise_data(); all_has_columnwise_data =
all_nvfp4 &= is_nvfp4_scaling(scaling_mode); (all_has_columnwise_data && input[i]->columnwise_scale_inv.has_data());
all_nvfp4 = all_nvfp4 && is_nvfp4_scaling(scaling_mode);
} }
NVTE_CHECK(all_has_data || all_has_columnwise_data, NVTE_CHECK(all_has_data || all_has_columnwise_data,
"All tensors should have data or columnwise data."); "All tensors should have data or columnwise data.");
NVTE_CHECK(!all_has_data || !all_has_columnwise_data,
"All tensors have both data and columnwise data.");
const bool rowwise_swizzle = all_has_data || all_nvfp4; const bool rowwise_swizzle = all_has_data || all_nvfp4;
const bool columnwise_swizzle = all_has_columnwise_data && !all_nvfp4; const bool columnwise_swizzle = all_has_columnwise_data && !all_nvfp4;
...@@ -644,18 +707,19 @@ void multi_tensor_swizzle_scaling_factors(const std::vector<Tensor*>& input, ...@@ -644,18 +707,19 @@ void multi_tensor_swizzle_scaling_factors(const std::vector<Tensor*>& input,
NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!"); NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!");
NVTE_CHECK(k > 0, "Input scale inverse should be 2D!"); NVTE_CHECK(k > 0, "Input scale inverse should be 2D!");
if (output[i]->has_data()) { if (all_has_data) {
NVTE_CHECK( NVTE_CHECK(output[i]->scale_inv.has_data(), "Output tensor ", i,
m * k == std::accumulate(output[i]->scale_inv.shape.begin(), " does not have row-wise scaling factors.");
output[i]->scale_inv.shape.end(), 1, std::multiplies<int>()), NVTE_CHECK(m * k == output[i]->scale_inv.numel(), "Expected output tensor ", i, " to have ",
"Input.scale_inv size is not equal to Output.scale_inv size!"); m * k, " row-wise scaling factors, but got shape=", output[i]->scale_inv.shape,
} ".");
if (output[i]->has_columnwise_data()) { }
NVTE_CHECK(m * k == std::accumulate(output[i]->columnwise_scale_inv.shape.begin(), if (all_has_columnwise_data) {
output[i]->columnwise_scale_inv.shape.end(), 1, NVTE_CHECK(output[i]->columnwise_scale_inv.has_data(), "Output tensor ", i,
std::multiplies<int>()), " does not have column-wise scaling factors.");
"Input.columnwise_scale_inv size is not equal to " NVTE_CHECK(m * k == output[i]->columnwise_scale_inv.numel(), "Expected output tensor ", i,
"Output.columnwise_scale_inv size!"); " to have ", m * k, " column-wise scaling factors, but got shape=",
output[i]->columnwise_scale_inv.shape, ".");
} }
int num_tiles_k = k / SF_TILE_DIM_K; int num_tiles_k = k / SF_TILE_DIM_K;
......
...@@ -77,7 +77,7 @@ std::string to_string(const NVTEScalingMode &mode) { ...@@ -77,7 +77,7 @@ std::string to_string(const NVTEScalingMode &mode) {
} }
void CheckNoopTensor(const Tensor &t, const std::string &name) { void CheckNoopTensor(const Tensor &t, const std::string &name) {
if (t.data.dptr != nullptr) { if (t.data.has_data()) {
NVTE_CHECK(t.numel() == 1, "Expected 1 element for ", name, " noop, but found ", t.numel(), NVTE_CHECK(t.numel() == 1, "Expected 1 element for ", name, " noop, but found ", t.numel(),
"."); ".");
NVTE_CHECK(t.data.dtype == DType::kFloat32, "Found wrong dtype for ", name, NVTE_CHECK(t.data.dtype == DType::kFloat32, "Found wrong dtype for ", name,
...@@ -88,16 +88,31 @@ void CheckNoopTensor(const Tensor &t, const std::string &name) { ...@@ -88,16 +88,31 @@ void CheckNoopTensor(const Tensor &t, const std::string &name) {
void CheckScaleTensorShape(const Tensor &t, const std::string &name) { void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
NVTE_CHECK(t.scaling_mode != NVTE_INVALID_SCALING, "Invalid scaling mode!"); NVTE_CHECK(t.scaling_mode != NVTE_INVALID_SCALING, "Invalid scaling mode!");
if (is_tensor_scaling(t.scaling_mode)) { if (is_tensor_scaling(t.scaling_mode)) {
// per-tensor scaling if (is_fp8_dtype(t.dtype())) {
// FP8 tensor with tensor scaling
if (t.has_data()) { if (t.has_data()) {
NVTE_CHECK(t.scale_inv.numel() == 1, "Tensor \"", name, NVTE_CHECK(t.scale_inv.numel() == 1, "Tensor \"", name,
"\" has invalid scale_inv shape (expected (1), got ", t.scale_inv.shape, ")"); "\" has invalid scale_inv shape (expected 1 entry, got ", t.scale_inv.shape,
")");
} }
if (t.has_columnwise_data()) { if (t.has_columnwise_data()) {
NVTE_CHECK(t.columnwise_scale_inv.numel() == 1, "Tensor \"", name, NVTE_CHECK(t.columnwise_scale_inv.numel() == 1, "Tensor \"", name,
"\" has invalid columnwise_scale_inv shape (expected (1), got ", "\" has invalid columnwise_scale_inv shape (expected 1 entry, got ",
t.columnwise_scale_inv.shape, ")");
}
} else {
// High-precision tensor
if (t.has_data()) {
NVTE_CHECK(t.scale_inv.numel() == 0, "Tensor \"", name,
"\" has invalid scale_inv shape (expected 0 entries, got ", t.scale_inv.shape,
")");
}
if (t.has_columnwise_data()) {
NVTE_CHECK(t.columnwise_scale_inv.numel() == 0, "Tensor \"", name,
"\" has invalid columnwise_scale_inv shape (expected 0 entries, got ",
t.columnwise_scale_inv.shape, ")"); t.columnwise_scale_inv.shape, ")");
} }
}
} else { } else {
if (t.scaling_mode == NVTE_MXFP8_1D_SCALING) { if (t.scaling_mode == NVTE_MXFP8_1D_SCALING) {
// Need (4, 128) alignment even for e8 scaling factor // Need (4, 128) alignment even for e8 scaling factor
...@@ -159,7 +174,7 @@ void CheckInputTensor(const Tensor &t, const std::string &name) { ...@@ -159,7 +174,7 @@ void CheckInputTensor(const Tensor &t, const std::string &name) {
if (is_fp8_dtype(type)) { if (is_fp8_dtype(type)) {
// FP8 input needs to have scale_inv // FP8 input needs to have scale_inv
if (t.has_data()) { if (t.has_data()) {
NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 scaling factor input ", name, NVTE_CHECK(t.scale_inv.has_data(), "FP8 scaling factor input ", name,
"_scale_inverse must be allocated"); "_scale_inverse must be allocated");
NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32 || t.scale_inv.dtype == DType::kFloat8E8M0, NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32 || t.scale_inv.dtype == DType::kFloat8E8M0,
"FP8 scaling factor input ", name, "FP8 scaling factor input ", name,
...@@ -168,7 +183,7 @@ void CheckInputTensor(const Tensor &t, const std::string &name) { ...@@ -168,7 +183,7 @@ void CheckInputTensor(const Tensor &t, const std::string &name) {
to_string(t.scale_inv.dtype), ")"); to_string(t.scale_inv.dtype), ")");
} }
if (t.has_columnwise_data()) { if (t.has_columnwise_data()) {
NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP8 scaling factor input ", name, NVTE_CHECK(t.columnwise_scale_inv.has_data(), "FP8 scaling factor input ", name,
"_columnwise_scale_inverse must be allocated"); "_columnwise_scale_inverse must be allocated");
NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat32 || NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat32 ||
t.columnwise_scale_inv.dtype == DType::kFloat8E8M0, t.columnwise_scale_inv.dtype == DType::kFloat8E8M0,
...@@ -181,7 +196,7 @@ void CheckInputTensor(const Tensor &t, const std::string &name) { ...@@ -181,7 +196,7 @@ void CheckInputTensor(const Tensor &t, const std::string &name) {
// TODO(ksivaman): Fix this to check for amaxes and other details. // TODO(ksivaman): Fix this to check for amaxes and other details.
// For now only needed for swizzle. // For now only needed for swizzle.
if (t.has_data()) { if (t.has_data()) {
NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP4 scaling factor input ", name, NVTE_CHECK(t.scale_inv.has_data(), "FP4 scaling factor input ", name,
"_scale_inverse must be allocated"); "_scale_inverse must be allocated");
NVTE_CHECK(t.scale_inv.dtype == DType::kFloat8E4M3, "FP4 scaling factor input ", name, NVTE_CHECK(t.scale_inv.dtype == DType::kFloat8E4M3, "FP4 scaling factor input ", name,
"_scale_inverse has invalid dtype " "_scale_inverse has invalid dtype "
...@@ -189,7 +204,7 @@ void CheckInputTensor(const Tensor &t, const std::string &name) { ...@@ -189,7 +204,7 @@ void CheckInputTensor(const Tensor &t, const std::string &name) {
to_string(t.scale_inv.dtype), ")"); to_string(t.scale_inv.dtype), ")");
} }
if (t.has_columnwise_data()) { if (t.has_columnwise_data()) {
NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP4 scaling factor input ", name, NVTE_CHECK(t.columnwise_scale_inv.has_data(), "FP4 scaling factor input ", name,
"_columnwise_scale_inverse must be allocated"); "_columnwise_scale_inverse must be allocated");
NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat8E4M3, "FP8 scaling factor input ", NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat8E4M3, "FP8 scaling factor input ",
name, name,
...@@ -198,11 +213,10 @@ void CheckInputTensor(const Tensor &t, const std::string &name) { ...@@ -198,11 +213,10 @@ void CheckInputTensor(const Tensor &t, const std::string &name) {
to_string(t.columnwise_scale_inv.dtype), ")"); to_string(t.columnwise_scale_inv.dtype), ")");
} }
} else { } else {
NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 input ", name); NVTE_CHECK(!t.scale.has_data(), "Scale is not supported for non-FP8 input ", name);
NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 input ", name); NVTE_CHECK(!t.scale_inv.has_data(), "Scale_inv is not supported for non-FP8 input ", name);
NVTE_CHECK(t.scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 input ", name); NVTE_CHECK(!t.columnwise_scale_inv.has_data(), "Scale_inv is not supported for non-FP8 input ",
NVTE_CHECK(t.columnwise_scale_inv.dptr == nullptr, name);
"Scale_inv is not supported for non-FP8 input ", name);
} }
NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Input ", name, " is not allocated!"); NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Input ", name, " is not allocated!");
...@@ -213,14 +227,14 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt ...@@ -213,14 +227,14 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
const DType type = t.dtype(); const DType type = t.dtype();
if (is_fp8_dtype(type)) { if (is_fp8_dtype(type)) {
// FP8 output needs to have scale, scale_inv and (if delayed scaling) amax // FP8 output needs to have scale, scale_inv and (if delayed scaling) amax
if (t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING && t.amax.dptr != nullptr) { if (t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING && t.amax.has_data()) {
NVTE_CHECK(t.amax.dtype == DType::kFloat32, "Invalid amax dtype (expected ", NVTE_CHECK(t.amax.dtype == DType::kFloat32, "Invalid amax dtype (expected ",
to_string(DType::kFloat32), ", got ", to_string(t.amax.dtype), ")"); to_string(DType::kFloat32), ", got ", to_string(t.amax.dtype), ")");
NVTE_CHECK(product(t.amax.shape) == 1, "Invalid shape of amax in output ", name, NVTE_CHECK(t.amax.numel() == 1, "Invalid shape of amax in output ", name,
" (expected 1 entry, got shape=", t.amax.shape, ")"); " (expected 1 entry, got shape=", t.amax.shape, ")");
} }
if (t.has_data()) { if (t.has_data()) {
NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 scaling factor output ", name, NVTE_CHECK(t.scale_inv.has_data(), "FP8 scaling factor output ", name,
"_scale_inverse must be allocated"); "_scale_inverse must be allocated");
NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32 || t.scale_inv.dtype == DType::kFloat8E8M0, NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32 || t.scale_inv.dtype == DType::kFloat8E8M0,
"FP8 scaling factor output ", name, "FP8 scaling factor output ", name,
...@@ -229,7 +243,7 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt ...@@ -229,7 +243,7 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
to_string(t.scale_inv.dtype), ")"); to_string(t.scale_inv.dtype), ")");
} }
if (t.has_columnwise_data()) { if (t.has_columnwise_data()) {
NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP8 scaling factor output ", name, NVTE_CHECK(t.columnwise_scale_inv.has_data(), "FP8 scaling factor output ", name,
"_columnwise_scale_inverse must be allocated"); "_columnwise_scale_inverse must be allocated");
NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat32 || NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat32 ||
t.columnwise_scale_inv.dtype == DType::kFloat8E8M0, t.columnwise_scale_inv.dtype == DType::kFloat8E8M0,
...@@ -241,7 +255,7 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt ...@@ -241,7 +255,7 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
} else if (is_fp4_dtype(type)) { } else if (is_fp4_dtype(type)) {
// FP4 output needs to have the scale_inv // FP4 output needs to have the scale_inv
if (t.has_data()) { if (t.has_data()) {
NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP4 scaling factor output ", name, NVTE_CHECK(t.scale_inv.has_data(), "FP4 scaling factor output ", name,
"_scale_inverse must be allocated"); "_scale_inverse must be allocated");
NVTE_CHECK(t.scale_inv.dtype == DType::kFloat8E4M3, "FP4 scaling factor output ", name, NVTE_CHECK(t.scale_inv.dtype == DType::kFloat8E4M3, "FP4 scaling factor output ", name,
"_scale_inverse has invalid dtype " "_scale_inverse has invalid dtype "
...@@ -249,7 +263,7 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt ...@@ -249,7 +263,7 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
to_string(t.scale_inv.dtype), ")"); to_string(t.scale_inv.dtype), ")");
} }
if (t.has_columnwise_data()) { if (t.has_columnwise_data()) {
NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP4 scaling factor output ", name, NVTE_CHECK(t.columnwise_scale_inv.has_data(), "FP4 scaling factor output ", name,
"_columnwise_scale_inverse must be allocated"); "_columnwise_scale_inverse must be allocated");
NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat8E4M3, "FP4 scaling factor output ", NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat8E4M3, "FP4 scaling factor output ",
name, name,
...@@ -258,12 +272,10 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt ...@@ -258,12 +272,10 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
to_string(t.columnwise_scale_inv.dtype), ")"); to_string(t.columnwise_scale_inv.dtype), ")");
} }
} else { } else {
NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 output ", name); NVTE_CHECK(!t.scale.has_data(), "Scale is not supported for non-FP8 output ", name);
// Unfused quant with level 2 nvfp4 scaling will produce high precision tensors with amax. NVTE_CHECK(!t.scale_inv.has_data(), "Scale_inv is not supported for non-FP8 output ", name);
// NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 output ", name); NVTE_CHECK(!t.columnwise_scale_inv.has_data(), "Scale_inv is not supported for non-FP8 input ",
NVTE_CHECK(t.scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 output ", name); name);
NVTE_CHECK(t.columnwise_scale_inv.dptr == nullptr,
"Scale_inv is not supported for non-FP8 input ", name);
} }
if (!allow_empty) { if (!allow_empty) {
...@@ -622,7 +634,11 @@ NVTEShape nvte_make_shape(const size_t *data, size_t ndim) { ...@@ -622,7 +634,11 @@ NVTEShape nvte_make_shape(const size_t *data, size_t ndim) {
NVTE_CHECK(ndim <= sizeof(ret.data) / sizeof(ret.data[0]), NVTE_CHECK(ndim <= sizeof(ret.data) / sizeof(ret.data[0]),
"Too many dims for NVTEShape (requested: ", ndim, "Too many dims for NVTEShape (requested: ", ndim,
", max: ", sizeof(ret.data) / sizeof(ret.data[0]), ")"); ", max: ", sizeof(ret.data) / sizeof(ret.data[0]), ")");
if (data == nullptr) {
std::fill(ret.data, ret.data + ndim, 0);
} else {
std::copy(data, data + ndim, ret.data); std::copy(data, data + ndim, ret.data);
}
ret.ndim = ndim; ret.ndim = ndim;
return ret; return ret;
} }
...@@ -729,7 +745,7 @@ void *nvte_tensor_columnwise_scale_inv(const NVTETensor tensor) { ...@@ -729,7 +745,7 @@ void *nvte_tensor_columnwise_scale_inv(const NVTETensor tensor) {
NVTEShape nvte_tensor_scale_inv_shape(const NVTETensor tensor) { NVTEShape nvte_tensor_scale_inv_shape(const NVTETensor tensor) {
auto *t = transformer_engine::convertNVTETensor(tensor); auto *t = transformer_engine::convertNVTETensor(tensor);
if (t == nullptr) { if (t == nullptr) {
return nvte_make_shape(nullptr, 0); return nvte_make_shape(nullptr, 1);
} }
return nvte_make_shape(t->scale_inv.shape.data(), t->scale_inv.shape.size()); return nvte_make_shape(t->scale_inv.shape.data(), t->scale_inv.shape.size());
} }
...@@ -768,7 +784,7 @@ void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name, ...@@ -768,7 +784,7 @@ void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name,
NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam param_name) { NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam param_name) {
if (tensor == nullptr) { if (tensor == nullptr) {
return {nullptr, kNVTEFloat32, nvte_make_shape(nullptr, 0)}; return {nullptr, kNVTEFloat32, nvte_make_shape(nullptr, 1)};
} }
const auto &t = *transformer_engine::convertNVTETensorCheck(tensor); const auto &t = *transformer_engine::convertNVTETensorCheck(tensor);
switch (param_name) { switch (param_name) {
...@@ -813,14 +829,21 @@ void nvte_tensor_pack_destroy(NVTETensorPack *pack) { ...@@ -813,14 +829,21 @@ void nvte_tensor_pack_destroy(NVTETensorPack *pack) {
void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream) { void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream) {
if (tensor == nullptr) return; if (tensor == nullptr) return;
const auto &t = *transformer_engine::convertNVTETensorCheck(tensor); const auto &t = *transformer_engine::convertNVTETensorCheck(tensor);
// Zero out tensor data if allocated // Zero out tensor data if allocated
if (t.data.dptr != nullptr) { if (t.data.dptr != nullptr) {
const size_t size_in_bytes = nvte_tensor_size_bytes(tensor); const auto size = t.data.buffer_size_bytes();
NVTE_CHECK_CUDA(cudaMemsetAsync(t.data.dptr, 0, size_in_bytes, stream)); if (size > 0) {
NVTE_CHECK_CUDA(cudaMemsetAsync(t.data.dptr, 0, size, stream));
} }
// Set amax to 0 if allocated }
// Zero out amax if allocated
if (t.amax.dptr != nullptr) { if (t.amax.dptr != nullptr) {
NVTE_CHECK_CUDA(cudaMemsetAsync(t.amax.dptr, 0, sizeof(float), stream)); const auto size = t.amax.buffer_size_bytes();
if (size > 0) {
NVTE_CHECK_CUDA(cudaMemsetAsync(t.amax.dptr, 0, size, stream));
}
} }
} }
...@@ -1007,7 +1030,7 @@ void nvte_set_grouped_tensor_param(NVTEGroupedTensor *tensor, NVTEGroupedTensorP ...@@ -1007,7 +1030,7 @@ void nvte_set_grouped_tensor_param(NVTEGroupedTensor *tensor, NVTEGroupedTensorP
NVTEBasicTensor nvte_get_grouped_tensor_param(const NVTEGroupedTensor tensor, NVTEBasicTensor nvte_get_grouped_tensor_param(const NVTEGroupedTensor tensor,
NVTEGroupedTensorParam param_name) { NVTEGroupedTensorParam param_name) {
if (tensor == nullptr) { if (tensor == nullptr) {
return {nullptr, kNVTEFloat32, nvte_make_shape(nullptr, 0)}; return {nullptr, kNVTEFloat32, nvte_make_shape(nullptr, 1)};
} }
const auto &t = *transformer_engine::convertNVTEGroupedTensorCheck(tensor); const auto &t = *transformer_engine::convertNVTEGroupedTensorCheck(tensor);
...@@ -1059,7 +1082,7 @@ NVTEScalingMode nvte_grouped_tensor_scaling_mode(const NVTEGroupedTensor tensor) ...@@ -1059,7 +1082,7 @@ NVTEScalingMode nvte_grouped_tensor_scaling_mode(const NVTEGroupedTensor tensor)
NVTEShape nvte_get_grouped_tensor_logical_shape(const NVTEGroupedTensor tensor) { NVTEShape nvte_get_grouped_tensor_logical_shape(const NVTEGroupedTensor tensor) {
if (tensor == nullptr) { if (tensor == nullptr) {
return nvte_make_shape(nullptr, 0); return nvte_make_shape(nullptr, 1);
} }
const auto &t = *transformer_engine::convertNVTEGroupedTensorCheck(tensor); const auto &t = *transformer_engine::convertNVTEGroupedTensorCheck(tensor);
return t.logical_shape; return t.logical_shape;
......
...@@ -198,8 +198,6 @@ void populate_cast_transpose_dbias_workspace_config(const Tensor &cast_output, / ...@@ -198,8 +198,6 @@ void populate_cast_transpose_dbias_workspace_config(const Tensor &cast_output, /
workspace->data.dtype); workspace->data.dtype);
const size_t required_size = const size_t required_size =
get_buffer_size_bytes(num_rows_partial_dbias, row_length, DType::kFloat32); get_buffer_size_bytes(num_rows_partial_dbias, row_length, DType::kFloat32);
NVTE_CHECK(!workspace->data.shape.empty(), "Invalid workspace dims (expected (",
num_rows_partial_dbias, ",", row_length, "), found ())");
NVTE_CHECK(workspace_size >= required_size, "Invalid workspace (expected dims=(", NVTE_CHECK(workspace_size >= required_size, "Invalid workspace (expected dims=(",
num_rows_partial_dbias, ",", row_length, "), dtype=", to_string(DType::kFloat32), num_rows_partial_dbias, ",", row_length, "), dtype=", to_string(DType::kFloat32),
"; found dims=", workspace->data.shape, "; found dims=", workspace->data.shape,
......
...@@ -388,8 +388,6 @@ void populate_transpose_dbias_workspace_config(const Tensor &input, /*cast*/ ...@@ -388,8 +388,6 @@ void populate_transpose_dbias_workspace_config(const Tensor &input, /*cast*/
workspace->data.dtype); workspace->data.dtype);
const size_t required_size = const size_t required_size =
get_buffer_size_bytes(num_rows_partial_dbias, row_length, DType::kFloat32); get_buffer_size_bytes(num_rows_partial_dbias, row_length, DType::kFloat32);
NVTE_CHECK(!workspace->data.shape.empty(), "Invalid workspace dims (expected (",
num_rows_partial_dbias, ",", row_length, "), found ())");
NVTE_CHECK(workspace_size >= required_size, "Invalid workspace (expected dims=(", NVTE_CHECK(workspace_size >= required_size, "Invalid workspace (expected dims=(",
num_rows_partial_dbias, ",", row_length, "), dtype=", to_string(DType::kFloat32), num_rows_partial_dbias, ",", row_length, "), dtype=", to_string(DType::kFloat32),
"; found dims=", workspace->data.shape, "; found dims=", workspace->data.shape,
......
...@@ -334,12 +334,12 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_fp ...@@ -334,12 +334,12 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_fp
tensor_cpp_list.emplace_back(makeTransformerEngineTensor( tensor_cpp_list.emplace_back(makeTransformerEngineTensor(
rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr, rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr,
columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr,
rowwise_usage ? rowwise_data_shapes[i] : std::vector<size_t>{}, rowwise_usage ? rowwise_data_shapes[i] : std::vector<size_t>{0},
columnwise_usage ? columnwise_data_shapes[i] : std::vector<size_t>{}, fp8_dtype, nullptr, columnwise_usage ? columnwise_data_shapes[i] : std::vector<size_t>{0}, fp8_dtype, nullptr,
nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr,
columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr,
rowwise_usage ? rowwise_scale_shapes[i] : std::vector<size_t>{}, rowwise_usage ? rowwise_scale_shapes[i] : std::vector<size_t>{0},
columnwise_usage ? columnwise_scale_shapes[i] : std::vector<size_t>{}, scaling_mode)); columnwise_usage ? columnwise_scale_shapes[i] : std::vector<size_t>{0}, scaling_mode));
} }
return retval; return retval;
...@@ -481,12 +481,12 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx ...@@ -481,12 +481,12 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
tensor_cpp_list.emplace_back(makeTransformerEngineTensor( tensor_cpp_list.emplace_back(makeTransformerEngineTensor(
rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr, rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr,
columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr,
rowwise_usage ? rowwise_data_shapes[i] : std::vector<size_t>{}, rowwise_usage ? rowwise_data_shapes[i] : std::vector<size_t>{0},
columnwise_usage ? columnwise_data_shapes[i] : std::vector<size_t>{}, fp8_dtype, nullptr, columnwise_usage ? columnwise_data_shapes[i] : std::vector<size_t>{0}, fp8_dtype, nullptr,
nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr,
columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr,
rowwise_usage ? rowwise_scale_shapes[i] : std::vector<size_t>{}, rowwise_usage ? rowwise_scale_shapes[i] : std::vector<size_t>{0},
columnwise_usage ? columnwise_scale_shapes[i] : std::vector<size_t>{}, scaling_mode)); columnwise_usage ? columnwise_scale_shapes[i] : std::vector<size_t>{0}, scaling_mode));
} }
return retval; return retval;
...@@ -685,13 +685,13 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, bool> bulk_alloc ...@@ -685,13 +685,13 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, bool> bulk_alloc
auto tensor_wrapper = makeTransformerEngineTensor( auto tensor_wrapper = makeTransformerEngineTensor(
rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr, rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr,
columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr,
rowwise_usage ? rowwise_data_shapes[i] : std::vector<size_t>{}, rowwise_usage ? rowwise_data_shapes[i] : std::vector<size_t>{0},
columnwise_usage ? columnwise_data_shapes[i] : std::vector<size_t>{}, fp4_dtype, columnwise_usage ? columnwise_data_shapes[i] : std::vector<size_t>{0}, fp4_dtype,
/*amax_ptr=*/nullptr, /*amax_ptr=*/nullptr,
/*scale_ptr=*/nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, /*scale_ptr=*/nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr,
columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr, columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr,
rowwise_usage ? rowwise_scale_shapes[i] : std::vector<size_t>{}, rowwise_usage ? rowwise_scale_shapes[i] : std::vector<size_t>{0},
columnwise_usage ? columnwise_scale_shapes[i] : std::vector<size_t>{}, scaling_mode); columnwise_usage ? columnwise_scale_shapes[i] : std::vector<size_t>{0}, scaling_mode);
// Set the amax rowwise and amax columnwise if available // Set the amax rowwise and amax columnwise if available
if (rowwise_usage) { if (rowwise_usage) {
......
...@@ -43,10 +43,10 @@ bool is_low_precision(const DType type) { ...@@ -43,10 +43,10 @@ bool is_low_precision(const DType type) {
std::vector<size_t> getGemmOutputShape(const NVTEShape& A_shape, const bool transa, std::vector<size_t> getGemmOutputShape(const NVTEShape& A_shape, const bool transa,
const NVTEShape& B_shape, const bool transb) { const NVTEShape& B_shape, const bool transb) {
// Flatten outer dims to get 2D matrices // Flatten outer dims to get 2D matrices
const size_t A0 = product(A_shape, 0, A_shape.ndim - 1); const size_t A0 = A_shape.ndim > 0 ? product(A_shape, 0, A_shape.ndim - 1) : 1;
const size_t A1 = A_shape.data[A_shape.ndim - 1]; const size_t A1 = A_shape.ndim > 0 ? A_shape.data[A_shape.ndim - 1] : 1;
const size_t B0 = product(B_shape, 0, B_shape.ndim - 1); const size_t B0 = B_shape.ndim > 0 ? product(B_shape, 0, B_shape.ndim - 1) : 1;
const size_t B1 = B_shape.data[B_shape.ndim - 1]; const size_t B1 = B_shape.ndim > 0 ? B_shape.data[B_shape.ndim - 1] : 1;
// Check matrix dims // Check matrix dims
NVTE_CHECK((transa ? A1 : A0) == (transb ? B0 : B1), "Invalid matrix dimensions for GEMM (A=(", NVTE_CHECK((transa ? A1 : A0) == (transb ? B0 : B1), "Invalid matrix dimensions for GEMM (A=(",
......
...@@ -22,8 +22,8 @@ void compute_amax(const at::Tensor& tensor, at::Tensor& amax) { ...@@ -22,8 +22,8 @@ void compute_amax(const at::Tensor& tensor, at::Tensor& amax) {
TORCH_CHECK(amax.numel() == 1, "amax must have exactly one element"); TORCH_CHECK(amax.numel() == 1, "amax must have exactly one element");
auto* amax_ptr = amax.data_ptr<float>(); auto* amax_ptr = amax.data_ptr<float>();
TensorWrapper fake_te_output( TensorWrapper fake_te_output(
nullptr, te_input.shape(), /*dptr=*/nullptr, te_input.shape(),
DType::kFloat8E4M3, // It doesn't matter because we only compute amax. DType::kFloat32, // It doesn't matter because we only compute amax.
amax_ptr); amax_ptr);
nvte_compute_amax(te_input.data(), fake_te_output.data(), at::cuda::getCurrentCUDAStream()); nvte_compute_amax(te_input.data(), fake_te_output.data(), at::cuda::getCurrentCUDAStream());
......
...@@ -142,13 +142,20 @@ std::optional<at::Tensor> multi_tensor_swizzle_scaling_factors( ...@@ -142,13 +142,20 @@ std::optional<at::Tensor> multi_tensor_swizzle_scaling_factors(
auto& tensor = tensors[i]; auto& tensor = tensors[i];
void* scale_inv_dptr = scale_inv_dptrs[i]; void* scale_inv_dptr = scale_inv_dptrs[i];
void* swizzled_scale_inv_dptr = getDataPtr(buffer, scale_inv_offsets[i]); void* swizzled_scale_inv_dptr = getDataPtr(buffer, scale_inv_offsets[i]);
// auto input_shape = nvte_shape_to_vector(tensor.shape());
// Empty tensors don't require scale swizzling
if (tensor.numel() == 0) {
continue;
}
// Tensor shape
NVTEShape nvte_input_shape; NVTEShape nvte_input_shape;
if (rowwise) { if (rowwise) {
nvte_input_shape = tensor.shape(); nvte_input_shape = tensor.shape();
} else { } else {
nvte_input_shape = tensor.get_columnwise_data().shape; nvte_input_shape = tensor.get_columnwise_data().shape;
} }
auto input_shape = nvte_shape_to_vector(nvte_input_shape); auto input_shape = nvte_shape_to_vector(nvte_input_shape);
// Reconstruct input only to avoid swizzling both directions if not needed. // Reconstruct input only to avoid swizzling both directions if not needed.
// Use any 8 bit type, it's irrelevant. // Use any 8 bit type, it's irrelevant.
...@@ -202,14 +209,14 @@ at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapp ...@@ -202,14 +209,14 @@ at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapp
size_t data_flat_last_dim = 1; size_t data_flat_last_dim = 1;
if (rowwise) { if (rowwise) {
data = input.get_rowwise_data(); data = input.get_rowwise_data();
for (int i = 0; i < data.shape.ndim - 1; ++i) { for (size_t i = 0; i < data.shape.ndim - 1; ++i) {
data_flat_first_dim *= data.shape.data[i]; data_flat_first_dim *= data.shape.data[i];
} }
data_flat_last_dim = data.shape.data[data.shape.ndim - 1]; data_flat_last_dim = data.shape.data[data.shape.ndim - 1];
} else { } else {
data = input.get_columnwise_data(); data = input.get_columnwise_data();
data_flat_first_dim = data.shape.data[0]; data_flat_first_dim = data.shape.data[0];
for (int i = 1; i < data.shape.ndim; ++i) { for (size_t i = 1; i < data.shape.ndim; ++i) {
data_flat_last_dim *= data.shape.data[i]; data_flat_last_dim *= data.shape.data[i];
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment