Unverified commit 61822061, authored by Tim Moon, committed by GitHub

[Core] Fix inconsistent logic in C++ tensor class (#2330)



* Initialize empty tensors with shape=[0] instead of shape=[].
Signed-off-by: Tim Moon <tmoon@nvidia.com>
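The distinction matters because the element count of a tensor is the product of its dimension sizes: shape=[] describes a 0-D scalar with one element, while shape=[0] describes a tensor with zero elements. A minimal sketch of the arithmetic (the numel helper is hypothetical, standing in for the library's own product utility):

    #include <cstddef>
    #include <functional>
    #include <numeric>
    #include <vector>

    // Element count as the product over dimension sizes.
    size_t numel(const std::vector<size_t>& shape) {
      return std::accumulate(shape.begin(), shape.end(), size_t{1},
                             std::multiplies<size_t>());
    }

    // numel({})  == 1  -> a 0-D scalar, not an empty tensor
    // numel({0}) == 0  -> a genuinely empty tensor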

* Fix runtime crash in LayerNorm

Still seeing correctness issues.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Make sure norm workspace sizes are not zero
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove assumption in swizzle kernel that data is available.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove assumption in multi-swizzle kernel that data is available.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Remove unnecessary explicit call to default constructor
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Avoid accessing tensor data pointer if tensor has no entries
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Apply suggestions from code review
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update transformer_engine/common/swizzle/swizzle.cu
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Review suggestions from @ptrendx and @greptile-apps
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Prefer using row-wise/col-wise shape based on which has data
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix merge conflict, expand docs, fix inconsistency in dim function
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Change Tensor::has_data to check whether tensor is initialized, not whether pointer is valid.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Review suggestion from @greptile-apps
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Debug incorrect tensor initialization in tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Clarify comments that has_data does not guarantee safe pointer accesses
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug test failure when computing amaxes
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
parent d126cdd6
......@@ -278,52 +278,33 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
Tensor::Tensor(const std::string& name,
const NVTEShape &shape, const DType type,
const bool rowwise, const bool columnwise,
const NVTEScalingMode &scaling_mode) {
name_ = name;
const NVTEScalingMode &scaling_mode)
: tensor_(scaling_mode), rowwise_{rowwise}, columnwise_{columnwise}, name_{name} {
// Initialize RNG
const size_t seed = create_seed_from_tensor_name(name);
gen_.seed(seed);
rowwise_ = rowwise;
columnwise_ = columnwise;
size_t total_size = bytes(shape, type);
void *dptr_rowwise = nullptr;
void *dptr_columnwise = nullptr;
cpu_data_rowwise_ = nullptr;
cpu_data_columnwise_ = nullptr;
amax_cpu_data_ = nullptr;
scale_cpu_data_ = nullptr;
rowwise_scale_inv_cpu_data_ = nullptr;
columnwise_scale_inv_cpu_data_ = nullptr;
float *amax = nullptr, *scale = nullptr;
float *rowwise_scale_inv = nullptr, *columnwise_scale_inv = nullptr;
// Make sure shape is valid
if (columnwise) {
NVTE_CHECK(shape.ndim >= 2);
}
std::vector<size_t> normalized_shape_v = {product(shape, 0, shape.ndim - 1),
shape.data[shape.ndim - 1]};
NVTEShape normalized_shape = convertShape(normalized_shape_v);
NVTEShape columnwise_shape = {};
std::vector<size_t> columnwise_shape_vec;
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING
|| scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) {
// Transpose when tensor scaling
columnwise_shape_vec.emplace_back(shape.data[shape.ndim - 1]);
for (size_t i = 0; i < shape.ndim - 1; ++i) {
columnwise_shape_vec.emplace_back(shape.data[i]);
}
} else {
// Same shape for MX and NVFP4
for (size_t i = 0; i < shape.ndim; ++i) {
columnwise_shape_vec.emplace_back(shape.data[i]);
}
}
if (columnwise) {
columnwise_shape = nvte_make_shape(columnwise_shape_vec.data(), columnwise_shape_vec.size());
}
tensor_ = TensorWrapper(scaling_mode);
// Shape after flattening to 2D
NVTEShape flattened_shape;
{
std::vector<size_t> flattened_shape_vec;
if (shape.ndim > 0) {
flattened_shape_vec.push_back(product(shape, 0, shape.ndim - 1));
flattened_shape_vec.push_back(shape.data[shape.ndim - 1]);
} else {
flattened_shape_vec.resize(2, 1);
}
flattened_shape = convertShape(flattened_shape_vec);
}
// Allocate and initialize data
void *dptr_rowwise = nullptr, *dptr_columnwise = nullptr;
const size_t total_size = bytes(shape, type);
if (total_size != 0) {
if (rowwise) {
cudaMalloc((void**)&dptr_rowwise, total_size); // NOLINT(*)
......@@ -339,11 +320,51 @@ Tensor::Tensor(const std::string& name,
}
}
// Set tensor row-wise data
if (rowwise) {
const DType rowwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type;
const DType colwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type;
tensor_.set_rowwise_data(dptr_rowwise, rowwise_type, shape);
}
// Set tensor column-wise data
if (columnwise) {
// Determine shape of column-wise data
std::vector<size_t> columnwise_shape_vec;
switch (scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING:
case NVTE_BLOCK_SCALING_1D:
case NVTE_BLOCK_SCALING_2D: {
// Column-wise data shape is transposed
if (shape.ndim > 0) {
columnwise_shape_vec.emplace_back(shape.data[shape.ndim - 1]);
for (size_t i = 0; i < shape.ndim - 1; ++i) {
columnwise_shape_vec.emplace_back(shape.data[i]);
}
}
break;
}
case NVTE_MXFP8_1D_SCALING:
case NVTE_NVFP4_1D_SCALING: {
// Column-wise data matches shape
for (size_t i = 0; i < shape.ndim; ++i) {
columnwise_shape_vec.emplace_back(shape.data[i]);
}
break;
}
default:
NVTE_ERROR("Unrecognized scaling mode (", (size_t)scaling_mode, ").");
}
const auto columnwise_shape = nvte_make_shape(columnwise_shape_vec.data(),
columnwise_shape_vec.size());
// Set column-wise data buffer
const DType colwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type;
tensor_.set_columnwise_data(dptr_columnwise, colwise_type, columnwise_shape);
}
// Configure scales, amaxes, and other tensor buffers
float *amax = nullptr, *scale = nullptr;
float *rowwise_scale_inv = nullptr, *columnwise_scale_inv = nullptr;
if (isFp8Type(type) || isFp4Type(type)) {
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
cudaMalloc((void**)&amax, sizeof(float)); // NOLINT(*)
......@@ -375,7 +396,7 @@ Tensor::Tensor(const std::string& name,
scale_cpu_data_ = std::make_shared<float>(0);
tensor_.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
}
auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(normalized_shape, tensor_.scaling_mode());
auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(flattened_shape, tensor_.scaling_mode());
auto rowwise_scale_size = rowwise_scale_meta.bytes();
auto columnwise_scale_size = colwise_scale_meta.bytes();
auto scale_shape = rowwise_scale_meta.shape;
......
......@@ -74,38 +74,49 @@ inline size_t product(const std::vector<size_t> &shape) {
return ret;
}
size_t get_buffer_size_bytes(const size_t N, const DType buffer_dtype);
size_t get_buffer_size_bytes(const size_t dim_first, const size_t dim_last,
const DType buffer_dtype);
struct SimpleTensor {
void *dptr;
std::vector<size_t> shape;
DType dtype;
SimpleTensor(void *dptr, const std::vector<size_t> &shape, DType dtype)
: dptr(dptr), shape(shape), dtype(dtype) {}
SimpleTensor(void *dptr, std::vector<size_t> shape, DType dtype)
: dptr{dptr}, shape{std::move(shape)}, dtype{dtype} {}
SimpleTensor(const NVTEBasicTensor &tensor) // NOLINT
: dptr(tensor.data_ptr),
shape(tensor.shape.data, tensor.shape.data + tensor.shape.ndim),
dtype(static_cast<DType>(tensor.dtype)) {}
SimpleTensor() : SimpleTensor(nullptr, {}, DType::kFloat32) {}
SimpleTensor() : SimpleTensor(nullptr, std::vector<size_t>{0}, DType::kFloat32) {}
operator NVTEBasicTensor() const {
return {dptr, static_cast<NVTEDType>(dtype),
nvte_make_shape(this->shape.data(), this->shape.size())};
}
size_t numel() const {
size_t acc = 1;
for (const auto &dim : shape) {
acc *= dim;
}
return acc;
}
bool has_data() const noexcept { return dptr != nullptr && numel() > 0; }
/*! Number of tensor elements. */
size_t numel() const { return product(shape); }
/*! Whether the tensor is initialized.
*
* A tensor with a non-trivial shape is considered initialized even
* if its data pointer is null, so this check does not guarantee
* that the pointer can be safely dereferenced.
*/
bool has_data() const { return !(dptr == nullptr && shape.size() == 1 && shape[0] == 0); }
/*! Buffer size in bytes. */
size_t buffer_size_bytes() const { return get_buffer_size_bytes(numel(), dtype); }
/*! Reset to uninitialized tensor. */
void clear() {
dptr = nullptr;
shape.resize(0);
shape.resize(1);
shape[0] = 0;
dtype = DType::kFloat32;
}
};
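For illustration, a rough sketch of how the revised SimpleTensor semantics behave (assuming the struct above; not taken from the test suite):

    // Default-constructed: dptr == nullptr and shape == {0}, so the
    // tensor is uninitialized and empty.
    SimpleTensor t;
    // t.has_data() == false, t.numel() == 0

    // A non-trivial shape counts as initialized even with a null
    // pointer, so has_data() alone does not make dereferencing safe.
    SimpleTensor meta(nullptr, std::vector<size_t>{128, 64}, DType::kFloat32);
    // meta.has_data() == true, but meta.dptr is still nullptr

    // clear() restores the uninitialized state: dptr = nullptr, shape = {0}.
    meta.clear();
    // meta.has_data() == false again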
......@@ -123,17 +134,9 @@ struct Tensor {
NVTEScalingMode scaling_mode;
NVTETensor nvte_tensor;
Tensor()
: data(),
columnwise_data(),
amax(nullptr, {1}, DType::kFloat32),
columnwise_amax(nullptr, {1}, DType::kFloat32),
scale(nullptr, {1}, DType::kFloat32),
scale_inv(nullptr, {1}, DType::kFloat32),
columnwise_scale_inv(nullptr, {1}, DType::kFloat32),
scaling_mode(NVTE_DELAYED_TENSOR_SCALING),
nvte_tensor(0) {}
Tensor() : scaling_mode{NVTE_DELAYED_TENSOR_SCALING}, nvte_tensor{0} {}
/*! Reset tensor data. */
void clear() {
data.clear();
columnwise_data.clear();
......@@ -147,65 +150,62 @@ struct Tensor {
explicit operator NVTETensor() const noexcept { return nvte_tensor; }
/*! Number of tensor elements. */
size_t numel() const {
size_t acc = 1;
for (const auto dim : shape()) {
acc *= dim;
}
return acc;
}
size_t numel() const {
if (!has_data() && has_columnwise_data()) {
return product(columnwise_data.shape);
}
return product(data.shape);
}
// TODO(Tim): Change this to use data.has_data()
bool has_data() const noexcept { return data.dptr != nullptr; }
/*! Whether the tensor data buffer is initialized.
*
* A buffer with a non-trivial shape is considered initialized even
* if its data pointer is null, so this check does not guarantee
* that the pointer can be safely dereferenced.
*/
bool has_data() const { return data.has_data(); }
// Check for size (not just pointer) for 0-dim or no token cases.
// TODO(Tim): Change this to use columnwise_data.has_data()
bool has_columnwise_data() const noexcept {
return columnwise_data.dptr != nullptr || columnwise_data.shape.size() != 0;
}
/*! Whether the tensor column-wise data buffer is initialized.
*
* A buffer with a non-trivial shape is considered initialized even
* if its data pointer is null, so this check does not guarantee
* that the pointer can be safely dereferenced.
*/
bool has_columnwise_data() const { return columnwise_data.has_data(); }
/*! Datatype of tensor elements. */
DType dtype() const {
if (has_data()) return data.dtype;
if (has_columnwise_data()) return columnwise_data.dtype;
// Fallback, used e.g. in workspace
return data.dtype;
}
DType dtype() const {
if (!has_data() && has_columnwise_data()) {
return columnwise_data.dtype;
}
return data.dtype;
}
/*! Number of tensor dimensions. */
size_t dim() const {
if (!has_data() && has_columnwise_data()) {
return columnwise_data.shape.size();
} else {
return data.shape.size();
}
return data.shape.size();
}
std::vector<size_t> shape() const {
/* Note: We sometimes experience spurious compiler errors
* (-Wstringop-overflow) from this function. It appears that GCC
* has some bugs with std::vector (see
* https://gcc.gnu.org/bugzilla/show_bug.cgi?id=109569).
/*! Tensor dimensions.
*
* This is the logical tensor shape. The underlying data may have a
* different shape, e.g. the column-wise data for some tensor
* formats are transposed.
*/
std::vector<size_t> shape() const {
// Each tensor format interprets its data differently
switch (scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING:
case NVTE_BLOCK_SCALING_1D:
case NVTE_BLOCK_SCALING_2D:
case NVTE_NVFP4_1D_SCALING: {
// Choose data buffer based on whether it is initialized
// Note: Uninitialized buffers currently have shape=[].
// However, this is logically incorrect. 0-D tensors have 1
// entry, and uninitialized tensors should have shape=[0].
bool use_columnwise_shape = false;
if (data.dptr != nullptr) {
use_columnwise_shape = false;
} else if (columnwise_data.dptr != nullptr) {
use_columnwise_shape = true;
} else if (data.shape.size() != 0) {
use_columnwise_shape = false;
} else if (columnwise_data.shape.size() != 0) {
use_columnwise_shape = true;
}
// Infer shape based on data
if (use_columnwise_shape) {
// Column-wise data is transposed
// Row-wise data shape matches tensor logical shape,
// column-wise data shape is transpose of logical shape
if (!has_data() && has_columnwise_data()) {
std::vector<size_t> ret;
if (!columnwise_data.shape.empty()) {
ret.reserve(columnwise_data.shape.size());
......@@ -218,38 +218,16 @@ struct Tensor {
}
return data.shape;
}
case NVTE_MXFP8_1D_SCALING:
case NVTE_MXFP8_1D_SCALING: {
// Row-wise and column-wise data shapes both match tensor
// logical shape
if (!has_data() && has_columnwise_data()) {
return columnwise_data.shape;
} else {
return data.shape;
}
break;
case NVTE_BLOCK_SCALING_1D:
case NVTE_BLOCK_SCALING_2D: {
if (!has_data() && has_columnwise_data()) {
std::vector<size_t> shape;
size_t ndim = columnwise_data.shape.size();
shape.reserve(ndim);
for (size_t i = 0; i + 1 < ndim; ++i) {
shape.push_back(columnwise_data.shape[i + 1]);
}
if (ndim > 0) {
shape.push_back(columnwise_data.shape[0]);
}
return shape;
} else {
// NOTE: We may have removed the data pointer from
// data by setting usage. In that case, we return
// the non-null shape. It is our best guess at the most
// recent shape.
return data.shape;
}
break;
}
default:
NVTE_ERROR("Cannot parse tensor shape with scaling mode \"", to_string(scaling_mode), "\"");
return {};
}
}
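As a sanity check on the transposed-layout cases above: for the tensor-scaling formats the column-wise buffer stores the last logical dimension first, so recovering the logical shape is a left rotation of the buffer shape. A hedged sketch as a standalone helper (not the class method itself):

    #include <cstddef>
    #include <vector>

    // Column-wise buffer shape [d_{n-1}, d_0, ..., d_{n-2}] ->
    // logical shape [d_0, ..., d_{n-1}] (tensor-scaling formats only).
    std::vector<size_t> logical_from_columnwise(const std::vector<size_t>& cw) {
      std::vector<size_t> out;
      out.reserve(cw.size());
      for (size_t i = 1; i < cw.size(); ++i) out.push_back(cw[i]);
      if (!cw.empty()) out.push_back(cw[0]);
      return out;
    }

    // logical_from_columnwise({64, 8, 16}) == {8, 16, 64}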
......@@ -347,10 +325,10 @@ struct GroupedTensor {
columnwise_amax(),
scale(),
num_tensors(num_tensors),
first_dims(nullptr, {}, DType::kInt64),
last_dims(nullptr, {}, DType::kInt64),
tensor_offsets(nullptr, {}, DType::kInt64),
logical_shape(nvte_make_shape(nullptr, 0)),
first_dims(nullptr, std::vector<size_t>{0}, DType::kInt64),
last_dims(nullptr, std::vector<size_t>{0}, DType::kInt64),
tensor_offsets(nullptr, std::vector<size_t>{0}, DType::kInt64),
logical_shape(nvte_make_shape(nullptr, 1)),
scaling_mode(scaling_mode),
nvte_tensor(0) {}
......@@ -383,9 +361,9 @@ struct GroupedTensor {
}
DType dtype() const {
if (has_data()) return data.dtype;
if (has_columnwise_data()) return columnwise_data.dtype;
// Fallback, used e.g. in workspace or when allow_empty=true
return data.dtype;
}
DType dtype() const {
if (!has_data() && has_columnwise_data()) {
return columnwise_data.dtype;
}
return data.dtype;
}
......@@ -400,7 +378,7 @@ struct GroupedTensor {
first_dims.clear();
last_dims.clear();
tensor_offsets.clear();
logical_shape = nvte_make_shape(nullptr, 0);
logical_shape = nvte_make_shape(nullptr, 1);
num_tensors = 0;
scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
nvte_tensor = 0;
......@@ -869,10 +847,6 @@ inline bool is_aligned_tensor_data(const Tensor &t, size_t alignment) {
size_t typeToSize(const DType type);
size_t typeToNumBits(const DType type);
size_t get_buffer_size_bytes(const size_t N, const DType buffer_dtype);
size_t get_buffer_size_bytes(const size_t dim_first, const size_t dim_last,
const DType buffer_dtype);
void CheckNoopTensor(const Tensor &t, const std::string &name);
void CheckInputTensor(const Tensor &t, const std::string &name);
void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty = false);
......
......@@ -142,7 +142,8 @@ void *nvte_tensor_columnwise_data(const NVTETensor tensor);
/*! \brief Construct a shape from an array of dimension sizes.
*
* \param[data] Pointer to start of shape array.
* \param[data] Pointer to start of shape array. If NULL, the shape
* will be filled with zeros.
* \param[ndim] Number of dimensions (must be <= 14)
*
* \return A shape. The shape will own its own copy of the data.
......@@ -575,15 +576,22 @@ class TensorWrapper {
*/
TensorWrapper(void *dptr, const NVTEShape &shape, const DType dtype, float *amax_dptr = nullptr,
float *scale_dptr = nullptr, float *scale_inv_dptr = nullptr,
const NVTEShape scale_inv_shape = defaultShape,
NVTEShape scale_inv_shape = defaultShape,
const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) {
tensor_ = nvte_create_tensor(scaling_mode);
NVTEBasicTensor data = {dptr, static_cast<NVTEDType>(dtype), shape};
nvte_set_tensor_param(&tensor_, kNVTERowwiseData, &data);
NVTEBasicTensor amax = {amax_dptr, kNVTEFloat32, defaultShape};
NVTEBasicTensor amax = {amax_dptr, kNVTEFloat32,
amax_dptr != nullptr ? defaultShape : emptyShape};
nvte_set_tensor_param(&tensor_, kNVTEAmax, &amax);
NVTEBasicTensor scale = {scale_dptr, kNVTEFloat32, defaultShape};
NVTEBasicTensor scale = {scale_dptr, kNVTEFloat32,
scale_dptr != nullptr ? defaultShape : emptyShape};
nvte_set_tensor_param(&tensor_, kNVTEScale, &scale);
if (scale_inv_dptr == nullptr && scale_inv_shape.ndim == defaultShape.ndim &&
scale_inv_shape.ndim == 1 && scale_inv_shape.data[0] == defaultShape.data[0]) {
// Scale-inv pointer has not been provided and shape matches default
scale_inv_shape = emptyShape;
}
NVTEBasicTensor scale_inv = {scale_inv_dptr, kNVTEFloat32, scale_inv_shape};
nvte_set_tensor_param(&tensor_, kNVTERowwiseScaleInv, &scale_inv);
}
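The net effect: metadata that is not supplied now carries shape [0] (zero entries) rather than the one-element default shape, so a null amax or scale can no longer masquerade as a valid one-element buffer. An illustrative (not authoritative) use of the constructor above, with dptr and shape assumed to be in scope:

    float scale = 1.0f;
    TensorWrapper t(dptr, shape, DType::kFloat8E4M3,
                    /*amax_dptr=*/nullptr, /*scale_dptr=*/&scale);
    // amax param  -> {nullptr, kNVTEFloat32, emptyShape}, i.e. shape [0]
    // scale param -> {&scale,  kNVTEFloat32, defaultShape}, i.e. shape [1]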
......@@ -734,7 +742,7 @@ class TensorWrapper {
*/
const NVTEShape shape() const noexcept {
if (tensor_ == nullptr) {
return nvte_make_shape(nullptr, 0);
return emptyShape;
}
return nvte_tensor_shape(tensor_);
}
......@@ -745,7 +753,7 @@ class TensorWrapper {
*/
const NVTEShape columnwise_shape() const noexcept {
if (tensor_ == nullptr) {
return nvte_make_shape(nullptr, 0);
return emptyShape;
}
return nvte_tensor_columnwise_shape(tensor_);
}
......@@ -869,7 +877,7 @@ class TensorWrapper {
*/
const NVTEShape scale_inv_shape() const noexcept {
if (tensor_ == nullptr) {
return nvte_make_shape(nullptr, 0);
return emptyShape;
}
return nvte_tensor_scale_inv_shape(tensor_);
}
......@@ -888,6 +896,7 @@ class TensorWrapper {
static constexpr size_t defaultData = 1;
static constexpr NVTEShape defaultShape = {
{defaultData, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1};
static constexpr NVTEShape emptyShape = {{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1};
private:
NVTEShape convertShape(const NVTEShape &s) { return s; }
......
......@@ -127,7 +127,13 @@ void TeNormalizationPlan<KernelParamsType>::_build() {
template <typename KernelParamsType>
std::vector<size_t> TeNormalizationPlan<KernelParamsType>::getWorkspaceShape() const {
return {_launch_params.getTotalWorkspaceBytes(_is_layernorm)};
size_t workspace_size = _launch_params.getTotalWorkspaceBytes(_is_layernorm);
if (workspace_size == 0) {
// Workspace size must not be zero since that corresponds to a
// workspace size query
workspace_size = 1;
}
return {workspace_size};
}
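The guard exists because callers signal a workspace-size query with an empty workspace: the first call only fills in the required shape and returns, and the second call does the actual compute. A sketch of that two-pass convention (argument lists abridged; allocate() is a hypothetical helper), mirroring the layernorm functions below:

    // Pass 1: workspace->data.numel() == 0, so the function writes the
    // required shape/dtype into the workspace tensor and returns early.
    Tensor workspace;
    layernorm_fwd(x, gamma, beta, epsilon, &z, &mu, &rsigma, &workspace, /*...*/);

    // Pass 2: after allocating to the reported size, the same call
    // takes the compute path.
    allocate(workspace);  // hypothetical: allocates workspace.data
    layernorm_fwd(x, gamma, beta, epsilon, &z, &mu, &rsigma, &workspace, /*...*/);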
template <typename KernelParamsType>
......@@ -405,7 +411,13 @@ void CudnnNormalizationPlan::_build() {
}
std::vector<size_t> CudnnNormalizationPlan::getWorkspaceShape() const {
return {static_cast<size_t>(_graph.get_workspace_size())};
size_t workspace_size = _graph.get_workspace_size();
if (workspace_size == 0) {
// Workspace size must not be zero since that corresponds to a
// workspace size query
workspace_size = 1;
}
return {workspace_size};
}
void CudnnNormalizationPlan::execute(Tensor* z, void* x_dptr, void* gamma_dptr, void* beta_dptr,
......
......@@ -51,7 +51,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
"RSigma must be 1D tensor with shape (x.shape[0],).");
NVTE_CHECK(rsigma->data.dtype == DType::kFloat32, "RSigma must be a float32 tensor.");
if (!workspace->data.shape.empty()) {
if (workspace->data.numel() != 0) {
CheckInputTensor(x, "x");
CheckInputTensor(gamma, "gamma");
CheckInputTensor(beta, "beta");
......@@ -94,7 +94,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training,
gamma_in_weight_dtype);
if (workspace->data.shape.empty()) {
if (workspace->data.numel() == 0) {
workspace->data.shape = plan->getWorkspaceShape();
workspace->data.dtype = DType::kByte;
return;
......@@ -146,7 +146,7 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te
NVTE_CHECK(dbeta->data.shape == gamma.data.shape);
NVTE_CHECK(dbeta->data.dtype == gamma.data.dtype);
if (!workspace->data.shape.empty()) {
if (workspace->data.numel() != 0) {
CheckInputTensor(dz, "dz");
CheckInputTensor(x, "x");
CheckInputTensor(mu, "mu");
......@@ -179,7 +179,7 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te
multiprocessorCount, zero_centered_gamma, is_aligned, NVTE_DELAYED_TENSOR_SCALING, true,
gamma_in_weight_dtype);
if (workspace->data.shape.empty()) {
if (workspace->data.numel() == 0) {
workspace->data.shape = plan->getWorkspaceShape();
workspace->data.dtype = DType::kByte;
return;
......
......@@ -39,7 +39,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
"RSigma must be 1D tensor with shape (x.shape[0],).");
NVTE_CHECK(rsigma->data.dtype == DType::kFloat32, "RSigma must be a float32 tensor.");
if (!workspace->data.shape.empty()) {
if (workspace->data.numel() != 0) {
CheckInputTensor(x, "x");
CheckInputTensor(gamma, "gamma");
......@@ -79,7 +79,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training,
gamma_in_weight_dtype);
if (workspace->data.shape.empty()) {
if (workspace->data.numel() == 0) {
workspace->data.shape = plan->getWorkspaceShape();
workspace->data.dtype = DType::kByte;
return;
......@@ -125,7 +125,7 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const
NVTE_CHECK(dgamma->data.shape == gamma.data.shape);
NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype);
if (!workspace->data.shape.empty()) {
if (workspace->data.numel() != 0) {
CheckInputTensor(dz, "dz");
CheckInputTensor(x, "x");
CheckInputTensor(rsigma, "rsigma");
......@@ -156,7 +156,7 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const
multiprocessorCount, zero_centered_gamma, is_aligned, NVTE_DELAYED_TENSOR_SCALING, true,
gamma_in_weight_dtype);
if (workspace->data.shape.empty()) {
if (workspace->data.numel() == 0) {
workspace->data.shape = plan->getWorkspaceShape();
workspace->data.dtype = DType::kByte;
return;
......@@ -191,7 +191,7 @@ void rmsnorm_bwd_add(const Tensor &dz, const Tensor &x, const Tensor &add, const
NVTE_CHECK(dgamma->data.shape == gamma.data.shape);
NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype);
if (!workspace->data.shape.empty()) {
if (workspace->data.numel() != 0) {
CheckInputTensor(dz, "dz");
CheckInputTensor(x, "x");
CheckInputTensor(add, "add");
......@@ -222,7 +222,7 @@ void rmsnorm_bwd_add(const Tensor &dz, const Tensor &x, const Tensor &add, const
multiprocessorCount, zero_centered_gamma, is_aligned, NVTE_DELAYED_TENSOR_SCALING, true,
gamma_in_weight_dtype);
if (workspace->data.shape.empty()) {
if (workspace->data.numel() == 0) {
workspace->data.shape = plan->getWorkspaceShape();
workspace->data.dtype = DType::kByte;
return;
......
......@@ -332,68 +332,118 @@ __global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_
} // namespace
void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t stream) {
NVTE_CHECK(
input->scaling_mode == NVTE_MXFP8_1D_SCALING || input->scaling_mode == NVTE_NVFP4_1D_SCALING,
// Check scaling mode
const auto& scaling_mode = input->scaling_mode;
NVTE_CHECK(scaling_mode == NVTE_MXFP8_1D_SCALING || scaling_mode == NVTE_NVFP4_1D_SCALING,
"Input tensor has invalid scaling mode (", to_string(input->scaling_mode), ").");
NVTE_CHECK(is_fp8_dtype(input->dtype()) || is_fp4_dtype(input->dtype()),
"Input tensor has invalid dtype (", to_string(input->dtype()), ").");
// Do nothing if tensor is empty
if (input->data.numel() == 0) {
return;
}
// Check tensors
CheckInputTensor(*input, "scaling_factor_input");
CheckInputTensor(*output, "scaling_factor_output");
switch (scaling_mode) {
case NVTE_MXFP8_1D_SCALING:
NVTE_CHECK(is_fp8_dtype(input->dtype()), "Input tensor has invalid dtype (expected FP8, got ",
to_string(input->dtype()), ").");
break;
case NVTE_NVFP4_1D_SCALING:
NVTE_CHECK(is_fp4_dtype(input->dtype()), "Input tensor has invalid dtype (expected FP4, got ",
to_string(input->dtype()), ").");
break;
default:
NVTE_ERROR("Invalid scaling mode");
}
auto& scaling_mode = input->scaling_mode;
NVTE_CHECK(scaling_mode == NVTE_MXFP8_1D_SCALING || scaling_mode == NVTE_NVFP4_1D_SCALING,
"Unsupported scaling mode for swizzling.");
bool nvfp4 = scaling_mode == NVTE_NVFP4_1D_SCALING;
// Check if scaling factors are non-trivial
const bool has_rowwise_scale_inv = input->scale_inv.has_data();
const bool has_columnwise_scale_inv = input->columnwise_scale_inv.has_data();
NVTE_CHECK(!has_rowwise_scale_inv || !has_columnwise_scale_inv,
"Input tensor has both row-wise and column-wise scaling factors");
if (!has_rowwise_scale_inv && !has_columnwise_scale_inv) {
return;
}
// 1D block scaling, row-wise or column-wise
int m, k;
if (input->has_data()) {
// Deduce tensor dims
int m{0}, k{0};
switch (scaling_mode) {
case NVTE_MXFP8_1D_SCALING: {
if (has_rowwise_scale_inv) {
NVTE_CHECK(input->scale_inv.shape.size() == 2,
"Expected 2D scaling factors, got shape=", input->scale_inv.shape, ".");
m = input->scale_inv.shape[0];
k = input->scale_inv.shape[1];
} else {
if (nvfp4) {
m = input->columnwise_scale_inv.shape[0];
k = input->columnwise_scale_inv.shape[1];
} else {
} else if (has_columnwise_scale_inv) {
NVTE_CHECK(input->columnwise_scale_inv.shape.size() == 2,
"Expected 2D scaling factors, got shape=", input->columnwise_scale_inv.shape,
".");
m = input->columnwise_scale_inv.shape[1];
k = input->columnwise_scale_inv.shape[0];
}
break;
}
case NVTE_NVFP4_1D_SCALING: {
if (has_rowwise_scale_inv) {
NVTE_CHECK(input->scale_inv.shape.size() == 2,
"Expected 2D scaling factors, got shape=", input->scale_inv.shape, ".");
m = input->scale_inv.shape[0];
k = input->scale_inv.shape[1];
} else if (has_columnwise_scale_inv) {
NVTE_CHECK(input->columnwise_scale_inv.shape.size() == 2,
"Expected 2D scaling factors, got shape=", input->columnwise_scale_inv.shape,
".");
m = input->columnwise_scale_inv.shape[0];
k = input->columnwise_scale_inv.shape[1];
}
break;
}
default:
NVTE_ERROR("Invalid scaling mode");
}
// Check dims
constexpr int SF_TILE_DIM_M = 128;
constexpr int SF_TILE_DIM_K = 4;
NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!");
NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!");
NVTE_CHECK(k > 0, "Input scale inverse should be 2D!");
if (output->has_data()) {
NVTE_CHECK(m * k == std::accumulate(output->scale_inv.shape.begin(),
output->scale_inv.shape.end(), 1, std::multiplies<int>()),
"Input.scale_inv size is not equal to Output.scale_inv size!");
}
if (output->has_columnwise_data()) {
NVTE_CHECK(m * k == std::accumulate(output->columnwise_scale_inv.shape.begin(),
output->columnwise_scale_inv.shape.end(), 1,
std::multiplies<int>()),
"Input.columnwise_scale_inv size is not equal to "
"Output.columnwise_scale_inv size!");
// Check that output tensor matches input tensor
if (has_rowwise_scale_inv) {
NVTE_CHECK(output->scale_inv.has_data(),
"Output tensor does not have row-wise scaling factors.");
NVTE_CHECK(m * k == output->scale_inv.numel(), "Expected output tensor to have ", m * k,
" row-wise scaling factors, but got shape=", output->scale_inv.shape, ".");
}
if (has_columnwise_scale_inv) {
NVTE_CHECK(output->columnwise_scale_inv.has_data(),
"Output tensor does not have column-wise scaling factors.");
NVTE_CHECK(
m * k == output->columnwise_scale_inv.numel(), "Expected output tensor to have ", m * k,
" column-wise scaling factors, but got shape=", output->columnwise_scale_inv.shape, ".");
}
int num_tiles_m = m / SF_TILE_DIM_M;
int num_tiles_k = k / SF_TILE_DIM_K;
// Choose swizzle implementation
bool rowwise_swizzle{false}, columnwise_swizzle{false};
switch (scaling_mode) {
case NVTE_MXFP8_1D_SCALING: {
rowwise_swizzle = has_rowwise_scale_inv;
columnwise_swizzle = has_columnwise_scale_inv;
break;
}
case NVTE_NVFP4_1D_SCALING: {
// NVFP4 column-wise data is transposed, so row-wise and
// column-wise scales have same swizzling format
rowwise_swizzle = true;
columnwise_swizzle = false;
break;
}
default:
NVTE_ERROR("Invalid scaling mode");
}
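The NVFP4 branch deserves a note: NVFP4 column-wise data is stored transposed, so its column-wise scaling factors already sit in the row-major tile layout and go through the row-wise swizzle with swapped dimensions instead of a separate column-wise kernel. Schematically (dimension bookkeeping only, matching the switch further down; NVFP4_BLOCK_SIZE is assumed to be the 16-element NVFP4 block):

    // NVFP4 with only column-wise scales: swap the flattened dims and
    // reuse the row-wise swizzle, reading from columnwise_scale_inv.
    original_M = input->flat_last_dim();
    original_K = input->flat_first_dim() / NVFP4_BLOCK_SIZE;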
// For NVFP4, the scale inverse for transposed data needs rowwise swizzle.
const bool rowwise_swizzle = input->has_data() || nvfp4;
const bool columnwise_swizzle = input->has_columnwise_data() && !nvfp4;
const dim3 block_size(TB_DIM, TB_DIM);
const int num_tiles_m = m / SF_TILE_DIM_M;
const int num_tiles_k = k / SF_TILE_DIM_K;
dim3 block_size(TB_DIM, TB_DIM);
// Perform row-wise swizzle
if (rowwise_swizzle) {
int vec_load_size = (num_tiles_k - 1) % 4 + 1;
/* there is no int3 and misaligned if using int4/int2 */
......@@ -402,21 +452,33 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m);
int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
int original_M, original_K;
void *input_scale_inv_ptr, *output_scale_inv_ptr;
if (!nvfp4 || input->has_data()) {
int block_scale_size = nvfp4 ? NVFP4_BLOCK_SIZE : MXFP8_BLOCK_SIZE;
int original_M{0}, original_K{0};
void *input_scale_inv_ptr{nullptr}, *output_scale_inv_ptr{nullptr};
switch (scaling_mode) {
case NVTE_MXFP8_1D_SCALING: {
original_M = input->flat_first_dim();
original_K = input->flat_last_dim() / block_scale_size;
original_K = input->flat_last_dim() / MXFP8_BLOCK_SIZE;
input_scale_inv_ptr = input->scale_inv.dptr;
output_scale_inv_ptr = output->scale_inv.dptr;
} else {
break;
}
case NVTE_NVFP4_1D_SCALING: {
if (has_rowwise_scale_inv) {
original_M = input->flat_first_dim();
original_K = input->flat_last_dim() / NVFP4_BLOCK_SIZE;
input_scale_inv_ptr = input->scale_inv.dptr;
output_scale_inv_ptr = output->scale_inv.dptr;
} else if (has_columnwise_scale_inv) {
original_M = input->flat_last_dim();
original_K = input->flat_first_dim() / NVFP4_BLOCK_SIZE;
input_scale_inv_ptr = input->columnwise_scale_inv.dptr;
output_scale_inv_ptr = output->columnwise_scale_inv.dptr;
}
break;
}
default:
NVTE_ERROR("Invalid scaling mode");
}
switch (vec_load_size) {
case 4:
......@@ -447,7 +509,10 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
NVTE_ERROR("Not valid vec_load_size.");
break;
}
NVTE_CHECK_CUDA(cudaGetLastError());
}
// Perform column-wise swizzle
if (columnwise_swizzle) {
int vec_load_size = (num_tiles_m - 1) % 4 + 1;
if (vec_load_size == 3) vec_load_size = 1; /* no int3 and misaligned if using int4/int2 */
......@@ -456,8 +521,6 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
const int original_M = input->flat_last_dim();
const int original_K = input->flat_first_dim() / MXFP8_BLOCK_SIZE;
// NVFP4 shouldn't end up here because it only needs rowwise swizzle
NVTE_CHECK(!nvfp4, "NVFP4 shouldn't end up here because it only needs rowwise swizzle");
switch (vec_load_size) {
case 4:
......@@ -491,9 +554,8 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
NVTE_ERROR("Not valid vec_load_size.");
break;
}
}
NVTE_CHECK_CUDA(cudaGetLastError());
}
}
template <int SF_TILE_DIM_M, int SF_TILE_DIM_K>
......@@ -595,17 +657,18 @@ void multi_tensor_swizzle_scaling_factors(const std::vector<Tensor*>& input,
(is_fp8 && is_mxfp8_scaling(scaling_mode)) || (is_fp4 && is_nvfp4_scaling(scaling_mode)),
"Not implemented scaling mode " + to_string(scaling_mode) + ".");
// We don't allow empty tensors. They should be filtered out before calling this function.
if (input[i]->data.numel() == 0) {
NVTE_ERROR("Tensor input[" + std::to_string(i) + "] is empty.");
}
NVTE_CHECK(input[i]->numel() != 0, "Tensor input[", i, "] is empty.");
CheckInputTensor(*input[i], "scaling_factor_input[" + std::to_string(i) + "]");
CheckInputTensor(*output[i], "scaling_factor_output[" + std::to_string(i) + "]");
all_has_data &= input[i]->has_data();
all_has_columnwise_data &= input[i]->has_columnwise_data();
all_nvfp4 &= is_nvfp4_scaling(scaling_mode);
all_has_data = all_has_data && input[i]->scale_inv.has_data();
all_has_columnwise_data =
(all_has_columnwise_data && input[i]->columnwise_scale_inv.has_data());
all_nvfp4 = all_nvfp4 && is_nvfp4_scaling(scaling_mode);
}
NVTE_CHECK(all_has_data || all_has_columnwise_data,
"All tensors should have data or columnwise data.");
NVTE_CHECK(!all_has_data || !all_has_columnwise_data,
"All tensors have both data and columnwise data.");
const bool rowwise_swizzle = all_has_data || all_nvfp4;
const bool columnwise_swizzle = all_has_columnwise_data && !all_nvfp4;
......@@ -644,18 +707,19 @@ void multi_tensor_swizzle_scaling_factors(const std::vector<Tensor*>& input,
NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!");
NVTE_CHECK(k > 0, "Input scale inverse should be 2D!");
if (output[i]->has_data()) {
NVTE_CHECK(
m * k == std::accumulate(output[i]->scale_inv.shape.begin(),
output[i]->scale_inv.shape.end(), 1, std::multiplies<int>()),
"Input.scale_inv size is not equal to Output.scale_inv size!");
}
if (output[i]->has_columnwise_data()) {
NVTE_CHECK(m * k == std::accumulate(output[i]->columnwise_scale_inv.shape.begin(),
output[i]->columnwise_scale_inv.shape.end(), 1,
std::multiplies<int>()),
"Input.columnwise_scale_inv size is not equal to "
"Output.columnwise_scale_inv size!");
if (all_has_data) {
NVTE_CHECK(output[i]->scale_inv.has_data(), "Output tensor ", i,
" does not have row-wise scaling factors.");
NVTE_CHECK(m * k == output[i]->scale_inv.numel(), "Expected output tensor ", i, " to have ",
m * k, " row-wise scaling factors, but got shape=", output[i]->scale_inv.shape,
".");
}
if (all_has_columnwise_data) {
NVTE_CHECK(output[i]->columnwise_scale_inv.has_data(), "Output tensor ", i,
" does not have column-wise scaling factors.");
NVTE_CHECK(m * k == output[i]->columnwise_scale_inv.numel(), "Expected output tensor ", i,
" to have ", m * k, " column-wise scaling factors, but got shape=",
output[i]->columnwise_scale_inv.shape, ".");
}
int num_tiles_k = k / SF_TILE_DIM_K;
......
......@@ -77,7 +77,7 @@ std::string to_string(const NVTEScalingMode &mode) {
}
void CheckNoopTensor(const Tensor &t, const std::string &name) {
if (t.data.dptr != nullptr) {
if (t.data.has_data()) {
NVTE_CHECK(t.numel() == 1, "Expected 1 element for ", name, " noop, but found ", t.numel(),
".");
NVTE_CHECK(t.data.dtype == DType::kFloat32, "Found wrong dtype for ", name,
......@@ -88,16 +88,31 @@ void CheckNoopTensor(const Tensor &t, const std::string &name) {
void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
NVTE_CHECK(t.scaling_mode != NVTE_INVALID_SCALING, "Invalid scaling mode!");
if (is_tensor_scaling(t.scaling_mode)) {
// per-tensor scaling
if (is_fp8_dtype(t.dtype())) {
// FP8 tensor with tensor scaling
if (t.has_data()) {
NVTE_CHECK(t.scale_inv.numel() == 1, "Tensor \"", name,
"\" has invalid scale_inv shape (expected (1), got ", t.scale_inv.shape, ")");
"\" has invalid scale_inv shape (expected 1 entry, got ", t.scale_inv.shape,
")");
}
if (t.has_columnwise_data()) {
NVTE_CHECK(t.columnwise_scale_inv.numel() == 1, "Tensor \"", name,
"\" has invalid columnwise_scale_inv shape (expected (1), got ",
"\" has invalid columnwise_scale_inv shape (expected 1 entry, got ",
t.columnwise_scale_inv.shape, ")");
}
} else {
// High-precision tensor
if (t.has_data()) {
NVTE_CHECK(t.scale_inv.numel() == 0, "Tensor \"", name,
"\" has invalid scale_inv shape (expected 0 entries, got ", t.scale_inv.shape,
")");
}
if (t.has_columnwise_data()) {
NVTE_CHECK(t.columnwise_scale_inv.numel() == 0, "Tensor \"", name,
"\" has invalid columnwise_scale_inv shape (expected 0 entries, got ",
t.columnwise_scale_inv.shape, ")");
}
}
} else {
if (t.scaling_mode == NVTE_MXFP8_1D_SCALING) {
// Need (4, 128) alignment even for e8 scaling factor
......@@ -159,7 +174,7 @@ void CheckInputTensor(const Tensor &t, const std::string &name) {
if (is_fp8_dtype(type)) {
// FP8 input needs to have scale_inv
if (t.has_data()) {
NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 scaling factor input ", name,
NVTE_CHECK(t.scale_inv.has_data(), "FP8 scaling factor input ", name,
"_scale_inverse must be allocated");
NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32 || t.scale_inv.dtype == DType::kFloat8E8M0,
"FP8 scaling factor input ", name,
......@@ -168,7 +183,7 @@ void CheckInputTensor(const Tensor &t, const std::string &name) {
to_string(t.scale_inv.dtype), ")");
}
if (t.has_columnwise_data()) {
NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP8 scaling factor input ", name,
NVTE_CHECK(t.columnwise_scale_inv.has_data(), "FP8 scaling factor input ", name,
"_columnwise_scale_inverse must be allocated");
NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat32 ||
t.columnwise_scale_inv.dtype == DType::kFloat8E8M0,
......@@ -181,7 +196,7 @@ void CheckInputTensor(const Tensor &t, const std::string &name) {
// TODO(ksivaman): Fix this to check for amaxes and other details.
// For now only needed for swizzle.
if (t.has_data()) {
NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP4 scaling factor input ", name,
NVTE_CHECK(t.scale_inv.has_data(), "FP4 scaling factor input ", name,
"_scale_inverse must be allocated");
NVTE_CHECK(t.scale_inv.dtype == DType::kFloat8E4M3, "FP4 scaling factor input ", name,
"_scale_inverse has invalid dtype "
......@@ -189,7 +204,7 @@ void CheckInputTensor(const Tensor &t, const std::string &name) {
to_string(t.scale_inv.dtype), ")");
}
if (t.has_columnwise_data()) {
NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP4 scaling factor input ", name,
NVTE_CHECK(t.columnwise_scale_inv.has_data(), "FP4 scaling factor input ", name,
"_columnwise_scale_inverse must be allocated");
NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat8E4M3, "FP8 scaling factor input ",
name,
......@@ -198,11 +213,10 @@ void CheckInputTensor(const Tensor &t, const std::string &name) {
to_string(t.columnwise_scale_inv.dtype), ")");
}
} else {
NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 input ", name);
NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 input ", name);
NVTE_CHECK(t.scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 input ", name);
NVTE_CHECK(t.columnwise_scale_inv.dptr == nullptr,
"Scale_inv is not supported for non-FP8 input ", name);
NVTE_CHECK(!t.scale.has_data(), "Scale is not supported for non-FP8 input ", name);
NVTE_CHECK(!t.scale_inv.has_data(), "Scale_inv is not supported for non-FP8 input ", name);
NVTE_CHECK(!t.columnwise_scale_inv.has_data(), "Scale_inv is not supported for non-FP8 input ",
name);
}
NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Input ", name, " is not allocated!");
......@@ -213,14 +227,14 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
const DType type = t.dtype();
if (is_fp8_dtype(type)) {
// FP8 output needs to have scale, scale_inv and (if delayed scaling) amax
if (t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING && t.amax.dptr != nullptr) {
if (t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING && t.amax.has_data()) {
NVTE_CHECK(t.amax.dtype == DType::kFloat32, "Invalid amax dtype (expected ",
to_string(DType::kFloat32), ", got ", to_string(t.amax.dtype), ")");
NVTE_CHECK(product(t.amax.shape) == 1, "Invalid shape of amax in output ", name,
NVTE_CHECK(t.amax.numel() == 1, "Invalid shape of amax in output ", name,
" (expected 1 entry, got shape=", t.amax.shape, ")");
}
if (t.has_data()) {
NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 scaling factor output ", name,
NVTE_CHECK(t.scale_inv.has_data(), "FP8 scaling factor output ", name,
"_scale_inverse must be allocated");
NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32 || t.scale_inv.dtype == DType::kFloat8E8M0,
"FP8 scaling factor output ", name,
......@@ -229,7 +243,7 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
to_string(t.scale_inv.dtype), ")");
}
if (t.has_columnwise_data()) {
NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP8 scaling factor output ", name,
NVTE_CHECK(t.columnwise_scale_inv.has_data(), "FP8 scaling factor output ", name,
"_columnwise_scale_inverse must be allocated");
NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat32 ||
t.columnwise_scale_inv.dtype == DType::kFloat8E8M0,
......@@ -241,7 +255,7 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
} else if (is_fp4_dtype(type)) {
// FP4 output needs to have the scale_inv
if (t.has_data()) {
NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP4 scaling factor output ", name,
NVTE_CHECK(t.scale_inv.has_data(), "FP4 scaling factor output ", name,
"_scale_inverse must be allocated");
NVTE_CHECK(t.scale_inv.dtype == DType::kFloat8E4M3, "FP4 scaling factor output ", name,
"_scale_inverse has invalid dtype "
......@@ -249,7 +263,7 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
to_string(t.scale_inv.dtype), ")");
}
if (t.has_columnwise_data()) {
NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP4 scaling factor output ", name,
NVTE_CHECK(t.columnwise_scale_inv.has_data(), "FP4 scaling factor output ", name,
"_columnwise_scale_inverse must be allocated");
NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat8E4M3, "FP4 scaling factor output ",
name,
......@@ -258,12 +272,10 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
to_string(t.columnwise_scale_inv.dtype), ")");
}
} else {
NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 output ", name);
// Unfused quant with level 2 nvfp4 scaling will produce high precision tensors with amax.
// NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 output ", name);
NVTE_CHECK(t.scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 output ", name);
NVTE_CHECK(t.columnwise_scale_inv.dptr == nullptr,
"Scale_inv is not supported for non-FP8 input ", name);
NVTE_CHECK(!t.scale.has_data(), "Scale is not supported for non-FP8 output ", name);
NVTE_CHECK(!t.scale_inv.has_data(), "Scale_inv is not supported for non-FP8 output ", name);
NVTE_CHECK(!t.columnwise_scale_inv.has_data(), "Scale_inv is not supported for non-FP8 input ",
name);
}
if (!allow_empty) {
......@@ -622,7 +634,11 @@ NVTEShape nvte_make_shape(const size_t *data, size_t ndim) {
NVTE_CHECK(ndim <= sizeof(ret.data) / sizeof(ret.data[0]),
"Too many dims for NVTEShape (requested: ", ndim,
", max: ", sizeof(ret.data) / sizeof(ret.data[0]), ")");
if (data == nullptr) {
std::fill(ret.data, ret.data + ndim, 0);
} else {
std::copy(data, data + ndim, ret.data);
}
ret.ndim = ndim;
return ret;
}
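With this change a null data pointer becomes a supported way to build an all-zero shape, which the rest of the patch uses as the canonical empty shape. For instance (illustrative):

    // Shape [0]: the canonical uninitialized/empty shape after this patch.
    NVTEShape empty = nvte_make_shape(nullptr, 1);
    // empty.ndim == 1, empty.data[0] == 0

    // Unchanged path: copy dimensions from a real array.
    const size_t dims[2] = {128, 64};
    NVTEShape s = nvte_make_shape(dims, 2);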
......@@ -729,7 +745,7 @@ void *nvte_tensor_columnwise_scale_inv(const NVTETensor tensor) {
NVTEShape nvte_tensor_scale_inv_shape(const NVTETensor tensor) {
auto *t = transformer_engine::convertNVTETensor(tensor);
if (t == nullptr) {
return nvte_make_shape(nullptr, 0);
return nvte_make_shape(nullptr, 1);
}
return nvte_make_shape(t->scale_inv.shape.data(), t->scale_inv.shape.size());
}
......@@ -768,7 +784,7 @@ void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name,
NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam param_name) {
if (tensor == nullptr) {
return {nullptr, kNVTEFloat32, nvte_make_shape(nullptr, 0)};
return {nullptr, kNVTEFloat32, nvte_make_shape(nullptr, 1)};
}
const auto &t = *transformer_engine::convertNVTETensorCheck(tensor);
switch (param_name) {
......@@ -813,14 +829,21 @@ void nvte_tensor_pack_destroy(NVTETensorPack *pack) {
void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream) {
if (tensor == nullptr) return;
const auto &t = *transformer_engine::convertNVTETensorCheck(tensor);
// Zero out tensor data if allocated
if (t.data.dptr != nullptr) {
const size_t size_in_bytes = nvte_tensor_size_bytes(tensor);
NVTE_CHECK_CUDA(cudaMemsetAsync(t.data.dptr, 0, size_in_bytes, stream));
const auto size = t.data.buffer_size_bytes();
if (size > 0) {
NVTE_CHECK_CUDA(cudaMemsetAsync(t.data.dptr, 0, size, stream));
}
// Set amax to 0 if allocated
}
// Zero out amax if allocated
if (t.amax.dptr != nullptr) {
NVTE_CHECK_CUDA(cudaMemsetAsync(t.amax.dptr, 0, sizeof(float), stream));
const auto size = t.amax.buffer_size_bytes();
if (size > 0) {
NVTE_CHECK_CUDA(cudaMemsetAsync(t.amax.dptr, 0, size, stream));
}
}
}
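buffer_size_bytes() is the element count times the element size, so the added size > 0 guards turn both memsets into no-ops for shape-[0] buffers instead of issuing cudaMemsetAsync on an empty (or fixed sizeof(float)) extent. A reduced sketch of the arithmetic (hypothetical standalone helper):

    #include <cstddef>
    #include <functional>
    #include <numeric>
    #include <vector>

    // An empty shape yields zero bytes, so the memset is skipped even
    // when a stale pointer is present.
    size_t buffer_bytes(const std::vector<size_t>& shape, size_t elem_size) {
      const size_t n = std::accumulate(shape.begin(), shape.end(), size_t{1},
                                       std::multiplies<size_t>());
      return n * elem_size;
    }

    // buffer_bytes({0}, 4) == 0 -> skip cudaMemsetAsync
    // buffer_bytes({1}, 4) == 4 -> zero the single float (e.g. amax)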
......@@ -1007,7 +1030,7 @@ void nvte_set_grouped_tensor_param(NVTEGroupedTensor *tensor, NVTEGroupedTensorP
NVTEBasicTensor nvte_get_grouped_tensor_param(const NVTEGroupedTensor tensor,
NVTEGroupedTensorParam param_name) {
if (tensor == nullptr) {
return {nullptr, kNVTEFloat32, nvte_make_shape(nullptr, 0)};
return {nullptr, kNVTEFloat32, nvte_make_shape(nullptr, 1)};
}
const auto &t = *transformer_engine::convertNVTEGroupedTensorCheck(tensor);
......@@ -1059,7 +1082,7 @@ NVTEScalingMode nvte_grouped_tensor_scaling_mode(const NVTEGroupedTensor tensor)
NVTEShape nvte_get_grouped_tensor_logical_shape(const NVTEGroupedTensor tensor) {
if (tensor == nullptr) {
return nvte_make_shape(nullptr, 0);
return nvte_make_shape(nullptr, 1);
}
const auto &t = *transformer_engine::convertNVTEGroupedTensorCheck(tensor);
return t.logical_shape;
......
......@@ -198,8 +198,6 @@ void populate_cast_transpose_dbias_workspace_config(const Tensor &cast_output, /
workspace->data.dtype);
const size_t required_size =
get_buffer_size_bytes(num_rows_partial_dbias, row_length, DType::kFloat32);
NVTE_CHECK(!workspace->data.shape.empty(), "Invalid workspace dims (expected (",
num_rows_partial_dbias, ",", row_length, "), found ())");
NVTE_CHECK(workspace_size >= required_size, "Invalid workspace (expected dims=(",
num_rows_partial_dbias, ",", row_length, "), dtype=", to_string(DType::kFloat32),
"; found dims=", workspace->data.shape,
......
......@@ -388,8 +388,6 @@ void populate_transpose_dbias_workspace_config(const Tensor &input, /*cast*/
workspace->data.dtype);
const size_t required_size =
get_buffer_size_bytes(num_rows_partial_dbias, row_length, DType::kFloat32);
NVTE_CHECK(!workspace->data.shape.empty(), "Invalid workspace dims (expected (",
num_rows_partial_dbias, ",", row_length, "), found ())");
NVTE_CHECK(workspace_size >= required_size, "Invalid workspace (expected dims=(",
num_rows_partial_dbias, ",", row_length, "), dtype=", to_string(DType::kFloat32),
"; found dims=", workspace->data.shape,
......
......@@ -334,12 +334,12 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_fp
tensor_cpp_list.emplace_back(makeTransformerEngineTensor(
rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr,
columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr,
rowwise_usage ? rowwise_data_shapes[i] : std::vector<size_t>{},
columnwise_usage ? columnwise_data_shapes[i] : std::vector<size_t>{}, fp8_dtype, nullptr,
rowwise_usage ? rowwise_data_shapes[i] : std::vector<size_t>{0},
columnwise_usage ? columnwise_data_shapes[i] : std::vector<size_t>{0}, fp8_dtype, nullptr,
nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr,
columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr,
rowwise_usage ? rowwise_scale_shapes[i] : std::vector<size_t>{},
columnwise_usage ? columnwise_scale_shapes[i] : std::vector<size_t>{}, scaling_mode));
rowwise_usage ? rowwise_scale_shapes[i] : std::vector<size_t>{0},
columnwise_usage ? columnwise_scale_shapes[i] : std::vector<size_t>{0}, scaling_mode));
}
return retval;
......@@ -481,12 +481,12 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>> bulk_allocate_mx
tensor_cpp_list.emplace_back(makeTransformerEngineTensor(
rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr,
columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr,
rowwise_usage ? rowwise_data_shapes[i] : std::vector<size_t>{},
columnwise_usage ? columnwise_data_shapes[i] : std::vector<size_t>{}, fp8_dtype, nullptr,
rowwise_usage ? rowwise_data_shapes[i] : std::vector<size_t>{0},
columnwise_usage ? columnwise_data_shapes[i] : std::vector<size_t>{0}, fp8_dtype, nullptr,
nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr,
columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr,
rowwise_usage ? rowwise_scale_shapes[i] : std::vector<size_t>{},
columnwise_usage ? columnwise_scale_shapes[i] : std::vector<size_t>{}, scaling_mode));
rowwise_usage ? rowwise_scale_shapes[i] : std::vector<size_t>{0},
columnwise_usage ? columnwise_scale_shapes[i] : std::vector<size_t>{0}, scaling_mode));
}
return retval;
......@@ -685,13 +685,13 @@ std::tuple<std::vector<py::object>, std::vector<TensorWrapper>, bool> bulk_alloc
auto tensor_wrapper = makeTransformerEngineTensor(
rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr,
columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr,
rowwise_usage ? rowwise_data_shapes[i] : std::vector<size_t>{},
columnwise_usage ? columnwise_data_shapes[i] : std::vector<size_t>{}, fp4_dtype,
rowwise_usage ? rowwise_data_shapes[i] : std::vector<size_t>{0},
columnwise_usage ? columnwise_data_shapes[i] : std::vector<size_t>{0}, fp4_dtype,
/*amax_ptr=*/nullptr,
/*scale_ptr=*/nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr,
columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr,
rowwise_usage ? rowwise_scale_shapes[i] : std::vector<size_t>{},
columnwise_usage ? columnwise_scale_shapes[i] : std::vector<size_t>{}, scaling_mode);
rowwise_usage ? rowwise_scale_shapes[i] : std::vector<size_t>{0},
columnwise_usage ? columnwise_scale_shapes[i] : std::vector<size_t>{0}, scaling_mode);
// Set the amax rowwise and amax columnwise if available
if (rowwise_usage) {
......
......@@ -43,10 +43,10 @@ bool is_low_precision(const DType type) {
std::vector<size_t> getGemmOutputShape(const NVTEShape& A_shape, const bool transa,
const NVTEShape& B_shape, const bool transb) {
// Flatten outer dims to get 2D matrices
const size_t A0 = product(A_shape, 0, A_shape.ndim - 1);
const size_t A1 = A_shape.data[A_shape.ndim - 1];
const size_t B0 = product(B_shape, 0, B_shape.ndim - 1);
const size_t B1 = B_shape.data[B_shape.ndim - 1];
const size_t A0 = A_shape.ndim > 0 ? product(A_shape, 0, A_shape.ndim - 1) : 1;
const size_t A1 = A_shape.ndim > 0 ? A_shape.data[A_shape.ndim - 1] : 1;
const size_t B0 = B_shape.ndim > 0 ? product(B_shape, 0, B_shape.ndim - 1) : 1;
const size_t B1 = B_shape.ndim > 0 ? B_shape.data[B_shape.ndim - 1] : 1;
// Check matrix dims
NVTE_CHECK((transa ? A1 : A0) == (transb ? B0 : B1), "Invalid matrix dimensions for GEMM (A=(",
......
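The ndim > 0 guards matter because, for a 0-D shape, A_shape.ndim - 1 underflows as an unsigned index and both product() and the data[] access would read out of bounds; a 0-D operand is instead treated as a 1x1 matrix. A small sketch of the flattening rule (hypothetical standalone helper):

    #include <cstddef>
    #include <utility>
    #include <vector>

    // Flatten an n-D shape to 2D as (product of leading dims, last dim);
    // a 0-D shape flattens to a 1x1 matrix.
    std::pair<size_t, size_t> flatten2d(const std::vector<size_t>& s) {
      if (s.empty()) return {1, 1};
      size_t lead = 1;
      for (size_t i = 0; i + 1 < s.size(); ++i) lead *= s[i];
      return {lead, s.back()};
    }

    // flatten2d({})        == {1, 1}
    // flatten2d({3, 4, 5}) == {12, 5}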
......@@ -22,8 +22,8 @@ void compute_amax(const at::Tensor& tensor, at::Tensor& amax) {
TORCH_CHECK(amax.numel() == 1, "amax must have exactly one element");
auto* amax_ptr = amax.data_ptr<float>();
TensorWrapper fake_te_output(
nullptr, te_input.shape(),
DType::kFloat8E4M3, // It doesn't matter because we only compute amax.
/*dptr=*/nullptr, te_input.shape(),
DType::kFloat32, // It doesn't matter because we only compute amax.
amax_ptr);
nvte_compute_amax(te_input.data(), fake_te_output.data(), at::cuda::getCurrentCUDAStream());
......
......@@ -142,13 +142,20 @@ std::optional<at::Tensor> multi_tensor_swizzle_scaling_factors(
auto& tensor = tensors[i];
void* scale_inv_dptr = scale_inv_dptrs[i];
void* swizzled_scale_inv_dptr = getDataPtr(buffer, scale_inv_offsets[i]);
// auto input_shape = nvte_shape_to_vector(tensor.shape());
// Empty tensors don't require scale swizzling
if (tensor.numel() == 0) {
continue;
}
// Tensor shape
NVTEShape nvte_input_shape;
if (rowwise) {
nvte_input_shape = tensor.shape();
} else {
nvte_input_shape = tensor.get_columnwise_data().shape;
}
auto input_shape = nvte_shape_to_vector(nvte_input_shape);
// Reconstruct input only to avoid swizzling both directions if not needed.
// Use any 8 bit type, it's irrelevant.
......@@ -202,14 +209,14 @@ at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapp
size_t data_flat_last_dim = 1;
if (rowwise) {
data = input.get_rowwise_data();
for (int i = 0; i < data.shape.ndim - 1; ++i) {
for (size_t i = 0; i < data.shape.ndim - 1; ++i) {
data_flat_first_dim *= data.shape.data[i];
}
data_flat_last_dim = data.shape.data[data.shape.ndim - 1];
} else {
data = input.get_columnwise_data();
data_flat_first_dim = data.shape.data[0];
for (int i = 1; i < data.shape.ndim; ++i) {
for (size_t i = 1; i < data.shape.ndim; ++i) {
data_flat_last_dim *= data.shape.data[i];
}
}
......