Unverified Commit 91405eb4 authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

Revert "Allow NVTEShape to own data." (#1703)

Revert "Allow NVTEShape to own data. (#1674)"

This reverts commit e61ce77c

.
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
parent 4742c0f8
...@@ -112,8 +112,8 @@ struct scale_inv_meta { ...@@ -112,8 +112,8 @@ struct scale_inv_meta {
size_t type_size; size_t type_size;
}; };
NVTEShape convertShape(const std::vector<size_t>& s) { NVTEShape convertShape(const std::vector<size_t>& shape) {
return nvte_make_shape(s.data(), s.size()); return {shape.data(), shape.size()};
} }
std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape, std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
...@@ -240,7 +240,7 @@ Tensor::Tensor(const std::string& name, ...@@ -240,7 +240,7 @@ Tensor::Tensor(const std::string& name,
std::vector<size_t> normalized_shape_v = {product(shape, 0, shape.ndim - 1), std::vector<size_t> normalized_shape_v = {product(shape, 0, shape.ndim - 1),
shape.data[shape.ndim - 1]}; shape.data[shape.ndim - 1]};
NVTEShape normalized_shape = convertShape(normalized_shape_v); NVTEShape normalized_shape = convertShape(normalized_shape_v);
NVTEShape columnwise_shape = {}; NVTEShape columnwise_shape{nullptr, 0};
std::vector<size_t> columnwise_shape_vec; std::vector<size_t> columnwise_shape_vec;
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING || scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) { if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING || scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) {
...@@ -257,7 +257,8 @@ Tensor::Tensor(const std::string& name, ...@@ -257,7 +257,8 @@ Tensor::Tensor(const std::string& name,
} }
if (columnwise) { if (columnwise) {
columnwise_shape = nvte_make_shape(columnwise_shape_vec.data(), columnwise_shape_vec.size()); columnwise_shape.data = columnwise_shape_vec.data();
columnwise_shape.ndim = columnwise_shape_vec.size();
} }
tensor_ = TensorWrapper(scaling_mode); tensor_ = TensorWrapper(scaling_mode);
......
...@@ -109,7 +109,7 @@ class Tensor { ...@@ -109,7 +109,7 @@ class Tensor {
const bool rowwise = true, const bool rowwise = true,
const bool columnwise = false, const bool columnwise = false,
const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING) : const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING) :
Tensor(name, nvte_make_shape(shape.data(), shape.size()), type, rowwise, columnwise, mode) {} Tensor(name, NVTEShape{shape.data(), shape.size()}, type, rowwise, columnwise, mode) {}
Tensor() {} Tensor() {}
......
...@@ -78,8 +78,8 @@ struct SimpleTensor { ...@@ -78,8 +78,8 @@ struct SimpleTensor {
SimpleTensor() : SimpleTensor(nullptr, {}, DType::kFloat32) {} SimpleTensor() : SimpleTensor(nullptr, {}, DType::kFloat32) {}
operator NVTEBasicTensor() const { operator NVTEBasicTensor() const {
return {dptr, static_cast<NVTEDType>(dtype), const NVTEShape shape = {this->shape.data(), this->shape.size()};
nvte_make_shape(this->shape.data(), this->shape.size())}; return {dptr, static_cast<NVTEDType>(dtype), shape};
} }
int numel() const { int numel() const {
...@@ -99,6 +99,11 @@ struct Tensor { ...@@ -99,6 +99,11 @@ struct Tensor {
SimpleTensor scale_inv; SimpleTensor scale_inv;
SimpleTensor columnwise_scale_inv; SimpleTensor columnwise_scale_inv;
private:
// Used as an allocation for nvte_tensor_shape
// if the shape has to be inferred from columnwise data.
mutable std::vector<size_t> rowwise_shape_cache;
public: public:
NVTEScalingMode scaling_mode; NVTEScalingMode scaling_mode;
...@@ -189,6 +194,22 @@ struct Tensor { ...@@ -189,6 +194,22 @@ struct Tensor {
} }
} }
const std::vector<size_t> &rowwise_shape_ref() const {
auto shape_queried = shape();
// This method is primarily designed for nvte_shape.
// An unfortunate consequence of unconditionally assigning
// values to rowwise_shape_cache without a check is that
// repeated calls to rowwise_shape_ref are likely to
// invalidate the data pointers from previous calls.
// If the shape has changed, then invalidating is necessary
// in at least some cases, but we want to keep the data
// valid otherwise.
if (rowwise_shape_cache != shape_queried) {
rowwise_shape_cache = std::move(shape_queried);
}
return rowwise_shape_cache;
}
/*! Matrix height after tensor is flattened to 2D /*! Matrix height after tensor is flattened to 2D
* *
* If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted * If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted
......
...@@ -42,8 +42,6 @@ struct NVTEShape { ...@@ -42,8 +42,6 @@ struct NVTEShape {
const size_t *data; const size_t *data;
/*! \brief Number of dimensions. */ /*! \brief Number of dimensions. */
size_t ndim; size_t ndim;
/*! \brief Copy of data. Num dims limited to permit fixed struct size.*/
size_t owned_data[14];
}; };
/*! \struct NVTEBasicTensor /*! \struct NVTEBasicTensor
...@@ -136,15 +134,6 @@ void *nvte_tensor_data(const NVTETensor tensor); ...@@ -136,15 +134,6 @@ void *nvte_tensor_data(const NVTETensor tensor);
*/ */
void *nvte_tensor_columnwise_data(const NVTETensor tensor); void *nvte_tensor_columnwise_data(const NVTETensor tensor);
/*! \brief Construct a shape from an array of dimension sizes.
*
* \param[data] Pointer to start of shape array.
* \param[data] Number of dimensions (must be <= 14)
*
* \return A shape. The shape will own its own copy of the data.
*/
NVTEShape nvte_make_shape(const size_t *data, size_t ndim);
/*! \brief Get a tensor's data shape. /*! \brief Get a tensor's data shape.
* *
* \param[in] tensor Tensor. * \param[in] tensor Tensor.
...@@ -428,9 +417,8 @@ class TensorWrapper { ...@@ -428,9 +417,8 @@ class TensorWrapper {
float *amax_dptr = nullptr, float *scale_dptr = nullptr, float *amax_dptr = nullptr, float *scale_dptr = nullptr,
float *scale_inv_dptr = nullptr, const std::vector<size_t> &scale_inv_shape = {1}, float *scale_inv_dptr = nullptr, const std::vector<size_t> &scale_inv_shape = {1},
const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING)
: TensorWrapper(dptr, nvte_make_shape(shape.data(), shape.size()), dtype, amax_dptr, : TensorWrapper(dptr, NVTEShape{shape.data(), shape.size()}, dtype, amax_dptr, scale_dptr,
scale_dptr, scale_inv_dptr, scale_inv_dptr, NVTEShape{scale_inv_shape.data(), scale_inv_shape.size()},
nvte_make_shape(scale_inv_shape.data(), scale_inv_shape.size()),
scaling_mode) {} scaling_mode) {}
/*! \brief Constructs new empty TensorWrapper. /*! \brief Constructs new empty TensorWrapper.
...@@ -546,9 +534,7 @@ class TensorWrapper { ...@@ -546,9 +534,7 @@ class TensorWrapper {
* \return Shape of this TensorWrapper. * \return Shape of this TensorWrapper.
*/ */
const NVTEShape shape() const noexcept { const NVTEShape shape() const noexcept {
if (tensor_ == nullptr) { if (tensor_ == nullptr) return NVTEShape{nullptr, 0};
return nvte_make_shape(nullptr, 0);
}
return nvte_tensor_shape(tensor_); return nvte_tensor_shape(tensor_);
} }
...@@ -557,9 +543,7 @@ class TensorWrapper { ...@@ -557,9 +543,7 @@ class TensorWrapper {
* \return Shape of this TensorWrapper. * \return Shape of this TensorWrapper.
*/ */
const NVTEShape columnwise_shape() const noexcept { const NVTEShape columnwise_shape() const noexcept {
if (tensor_ == nullptr) { if (tensor_ == nullptr) return NVTEShape{nullptr, 0};
return nvte_make_shape(nullptr, 0);
}
return nvte_tensor_columnwise_shape(tensor_); return nvte_tensor_columnwise_shape(tensor_);
} }
...@@ -672,9 +656,7 @@ class TensorWrapper { ...@@ -672,9 +656,7 @@ class TensorWrapper {
* \return scale_inv_shape of this TensorWrapper. * \return scale_inv_shape of this TensorWrapper.
*/ */
const NVTEShape scale_inv_shape() const noexcept { const NVTEShape scale_inv_shape() const noexcept {
if (tensor_ == nullptr) { if (tensor_ == nullptr) return NVTEShape{nullptr, 0};
return nvte_make_shape(nullptr, 0);
}
return nvte_tensor_scale_inv_shape(tensor_); return nvte_tensor_scale_inv_shape(tensor_);
} }
...@@ -690,20 +672,12 @@ class TensorWrapper { ...@@ -690,20 +672,12 @@ class TensorWrapper {
void zero_(cudaStream_t stream) { nvte_zero_tensor(tensor_, stream); } void zero_(cudaStream_t stream) { nvte_zero_tensor(tensor_, stream); }
static constexpr size_t defaultData = 1; static constexpr size_t defaultData = 1;
static constexpr NVTEShape defaultShape = { static constexpr NVTEShape defaultShape = {&defaultData, 1};
&defaultData, 1, {defaultData, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}};
private: private:
NVTEShape convertShape(const NVTEShape &s) { NVTEShape convertShape(const NVTEShape &s) { return s; }
NVTEShape ret = s;
// Move the ownership rather than pointing to the parent shape.
ret.data = ret.owned_data;
return ret;
}
NVTEShape convertShape(const std::vector<size_t> &s) { NVTEShape convertShape(const std::vector<size_t> &s) { return {s.data(), s.size()}; }
return nvte_make_shape(s.data(), s.size());
}
/*! \brief Wrapped NVTETensor. */ /*! \brief Wrapped NVTETensor. */
NVTETensor tensor_ = nullptr; NVTETensor tensor_ = nullptr;
......
...@@ -211,22 +211,6 @@ NVTEDType nvte_tensor_type(const NVTETensor tensor) { ...@@ -211,22 +211,6 @@ NVTEDType nvte_tensor_type(const NVTETensor tensor) {
reinterpret_cast<const transformer_engine::Tensor *>(tensor)->dtype()); reinterpret_cast<const transformer_engine::Tensor *>(tensor)->dtype());
} }
NVTEShape nvte_make_shape(const size_t *data, size_t ndim) {
NVTEShape ret;
if (ndim == 0) {
ret.data = nullptr;
ret.ndim = 0;
return ret;
}
NVTE_CHECK(ndim <= sizeof(ret.owned_data) / sizeof(ret.owned_data[0]),
"Too many dims for NVTEShape (requested: ", ndim,
", max: ", sizeof(ret.owned_data) / sizeof(ret.owned_data[0]), ")");
std::copy(data, data + ndim, ret.owned_data);
ret.data = ret.owned_data;
ret.ndim = ndim;
return ret;
}
NVTEShape nvte_tensor_shape(const NVTETensor tensor) { NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
if (tensor == nullptr) { if (tensor == nullptr) {
NVTE_ERROR("Invalid tensor"); NVTE_ERROR("Invalid tensor");
...@@ -234,9 +218,12 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor) { ...@@ -234,9 +218,12 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
// Determine tensor shape depending on tensor format // Determine tensor shape depending on tensor format
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor); const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
std::vector<size_t> shape = t.shape(); const std::vector<size_t> &rowwise_shape = t.rowwise_shape_ref();
return nvte_make_shape(shape.data(), shape.size()); NVTEShape ret;
ret.data = rowwise_shape.data();
ret.ndim = rowwise_shape.size();
return ret;
} }
NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) { NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) {
...@@ -244,7 +231,10 @@ NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) { ...@@ -244,7 +231,10 @@ NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) {
NVTE_ERROR("Invalid tensor"); NVTE_ERROR("Invalid tensor");
} }
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor); const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
return nvte_make_shape(t.columnwise_data.shape.data(), t.columnwise_data.shape.size()); NVTEShape ret;
ret.data = t.columnwise_data.shape.data();
ret.ndim = t.columnwise_data.shape.size();
return ret;
} }
size_t nvte_tensor_ndims(const NVTETensor tensor) { return nvte_tensor_shape(tensor).ndim; } size_t nvte_tensor_ndims(const NVTETensor tensor) { return nvte_tensor_shape(tensor).ndim; }
...@@ -312,11 +302,12 @@ void *nvte_tensor_columnwise_scale_inv(const NVTETensor tensor) { ...@@ -312,11 +302,12 @@ void *nvte_tensor_columnwise_scale_inv(const NVTETensor tensor) {
} }
NVTEShape nvte_tensor_scale_inv_shape(const NVTETensor tensor) { NVTEShape nvte_tensor_scale_inv_shape(const NVTETensor tensor) {
if (tensor == nullptr) { if (tensor == nullptr) return {nullptr, 0};
return nvte_make_shape(nullptr, 0);
}
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor); const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
return nvte_make_shape(t.scale_inv.shape.data(), t.scale_inv.shape.size()); NVTEShape ret;
ret.data = t.scale_inv.shape.data();
ret.ndim = t.scale_inv.shape.size();
return ret;
} }
void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name, void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name,
......
...@@ -3,11 +3,9 @@ ...@@ -3,11 +3,9 @@
* *
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include "extensions.h" #include "extensions.h"
#include "kv_cache.cuh" #include "kv_cache.cuh"
#include "thd_utils.cuh" #include "thd_utils.cuh"
#include "transformer_engine/transformer_engine.h"
constexpr int block_size = 512; constexpr int block_size = 512;
constexpr int ctas_per_sm = 4; constexpr int ctas_per_sm = 4;
...@@ -451,13 +449,13 @@ std::vector<py::object> fused_attn_bwd( ...@@ -451,13 +449,13 @@ std::vector<py::object> fused_attn_bwd(
nvte_tensor_pack_create(&nvte_aux_tensor_pack); nvte_tensor_pack_create(&nvte_aux_tensor_pack);
nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size(); nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size();
for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) {
const std::vector<int64_t> &signed_shape = Aux_CTX_Tensors[i].sizes().vec(); std::vector<int64_t> tmp(Aux_CTX_Tensors[i].sizes().vec());
const std::vector<size_t> tmp(signed_shape.begin(), signed_shape.end()); auto temp_vec = std::vector<size_t>(tmp.begin(), tmp.end());
const NVTEShape temp_shape = {temp_vec.data(), temp_vec.size()};
NVTEBasicTensor temp_data = { NVTEBasicTensor temp_data = {
Aux_CTX_Tensors[i].data_ptr(), Aux_CTX_Tensors[i].data_ptr(),
static_cast<NVTEDType>(GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type())), static_cast<NVTEDType>(GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type())),
nvte_make_shape(tmp.data(), tmp.size())}; temp_shape};
nvte_set_tensor_param(&nvte_aux_tensor_pack.tensors[i], kNVTERowwiseData, &temp_data); nvte_set_tensor_param(&nvte_aux_tensor_pack.tensors[i], kNVTERowwiseData, &temp_data);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment