Unverified Commit e61ce77c authored by kwyss-nvidia's avatar kwyss-nvidia Committed by GitHub
Browse files

Allow NVTEShape to own data. (#1674)



* Allow NVTEShape to own data.
Signed-off-by: default avatarKeith Wyss <kwyss@nvidia.com>

* Convert repeated copy paths to nvte_make_shape calls.
Signed-off-by: default avatarKeith Wyss <kwyss@nvidia.com>

* Apply suggestions from code review
Signed-off-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>

* Build fixes.
Signed-off-by: default avatarKeith Wyss <kwyss@nvidia.com>

* MR feedback.
Signed-off-by: default avatarKeith Wyss <kwyss@nvidia.com>

---------
Signed-off-by: default avatarKeith Wyss <kwyss@nvidia.com>
Signed-off-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
parent 61f1bf6f
...@@ -112,8 +112,8 @@ struct scale_inv_meta { ...@@ -112,8 +112,8 @@ struct scale_inv_meta {
size_t type_size; size_t type_size;
}; };
NVTEShape convertShape(const std::vector<size_t>& shape) { NVTEShape convertShape(const std::vector<size_t>& s) {
return {shape.data(), shape.size()}; return nvte_make_shape(s.data(), s.size());
} }
std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape, std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
...@@ -240,7 +240,7 @@ Tensor::Tensor(const std::string& name, ...@@ -240,7 +240,7 @@ Tensor::Tensor(const std::string& name,
std::vector<size_t> normalized_shape_v = {product(shape, 0, shape.ndim - 1), std::vector<size_t> normalized_shape_v = {product(shape, 0, shape.ndim - 1),
shape.data[shape.ndim - 1]}; shape.data[shape.ndim - 1]};
NVTEShape normalized_shape = convertShape(normalized_shape_v); NVTEShape normalized_shape = convertShape(normalized_shape_v);
NVTEShape columnwise_shape{nullptr, 0}; NVTEShape columnwise_shape = {};
std::vector<size_t> columnwise_shape_vec; std::vector<size_t> columnwise_shape_vec;
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING || scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) { if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING || scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) {
...@@ -257,8 +257,7 @@ Tensor::Tensor(const std::string& name, ...@@ -257,8 +257,7 @@ Tensor::Tensor(const std::string& name,
} }
if (columnwise) { if (columnwise) {
columnwise_shape.data = columnwise_shape_vec.data(); columnwise_shape = nvte_make_shape(columnwise_shape_vec.data(), columnwise_shape_vec.size());
columnwise_shape.ndim = columnwise_shape_vec.size();
} }
tensor_ = TensorWrapper(scaling_mode); tensor_ = TensorWrapper(scaling_mode);
......
...@@ -109,7 +109,7 @@ class Tensor { ...@@ -109,7 +109,7 @@ class Tensor {
const bool rowwise = true, const bool rowwise = true,
const bool columnwise = false, const bool columnwise = false,
const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING) : const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING) :
Tensor(name, NVTEShape{shape.data(), shape.size()}, type, rowwise, columnwise, mode) {} Tensor(name, nvte_make_shape(shape.data(), shape.size()), type, rowwise, columnwise, mode) {}
Tensor() {} Tensor() {}
......
...@@ -78,8 +78,8 @@ struct SimpleTensor { ...@@ -78,8 +78,8 @@ struct SimpleTensor {
SimpleTensor() : SimpleTensor(nullptr, {}, DType::kFloat32) {} SimpleTensor() : SimpleTensor(nullptr, {}, DType::kFloat32) {}
operator NVTEBasicTensor() const { operator NVTEBasicTensor() const {
const NVTEShape shape = {this->shape.data(), this->shape.size()}; return {dptr, static_cast<NVTEDType>(dtype),
return {dptr, static_cast<NVTEDType>(dtype), shape}; nvte_make_shape(this->shape.data(), this->shape.size())};
} }
int numel() const { int numel() const {
...@@ -99,11 +99,6 @@ struct Tensor { ...@@ -99,11 +99,6 @@ struct Tensor {
SimpleTensor scale_inv; SimpleTensor scale_inv;
SimpleTensor columnwise_scale_inv; SimpleTensor columnwise_scale_inv;
private:
// Used as an allocation for nvte_tensor_shape
// if the shape has to be inferred from columnwise data.
mutable std::vector<size_t> rowwise_shape_cache;
public: public:
NVTEScalingMode scaling_mode; NVTEScalingMode scaling_mode;
...@@ -194,22 +189,6 @@ struct Tensor { ...@@ -194,22 +189,6 @@ struct Tensor {
} }
} }
const std::vector<size_t> &rowwise_shape_ref() const {
auto shape_queried = shape();
// This method is primarily designed for nvte_shape.
// An unfortunate consequence of unconditionally assigning
// values to rowwise_shape_cache without a check is that
// repeated calls to rowwise_shape_ref are likely to
// invalidate the data pointers from previous calls.
// If the shape has changed, then invalidating is necessary
// in at least some cases, but we want to keep the data
// valid otherwise.
if (rowwise_shape_cache != shape_queried) {
rowwise_shape_cache = std::move(shape_queried);
}
return rowwise_shape_cache;
}
/*! Matrix height after tensor is flattened to 2D /*! Matrix height after tensor is flattened to 2D
* *
* If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted * If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted
......
...@@ -42,6 +42,8 @@ struct NVTEShape { ...@@ -42,6 +42,8 @@ struct NVTEShape {
const size_t *data; const size_t *data;
/*! \brief Number of dimensions. */ /*! \brief Number of dimensions. */
size_t ndim; size_t ndim;
/*! \brief Owned copy of the shape data. The number of dimensions is capped so the struct keeps a fixed size. */
size_t owned_data[14];
}; };
/*! \struct NVTEBasicTensor /*! \struct NVTEBasicTensor
...@@ -134,6 +136,15 @@ void *nvte_tensor_data(const NVTETensor tensor); ...@@ -134,6 +136,15 @@ void *nvte_tensor_data(const NVTETensor tensor);
*/ */
void *nvte_tensor_columnwise_data(const NVTETensor tensor); void *nvte_tensor_columnwise_data(const NVTETensor tensor);
/*! \brief Construct a shape from an array of dimension sizes.
*
 *  \param[in] data Pointer to the start of the shape array.
 *  \param[in] ndim Number of dimensions (must be <= 14).
*
* \return A shape. The shape will own its own copy of the data.
*/
NVTEShape nvte_make_shape(const size_t *data, size_t ndim);
/*! \brief Get a tensor's data shape. /*! \brief Get a tensor's data shape.
* *
* \param[in] tensor Tensor. * \param[in] tensor Tensor.
...@@ -417,8 +428,9 @@ class TensorWrapper { ...@@ -417,8 +428,9 @@ class TensorWrapper {
float *amax_dptr = nullptr, float *scale_dptr = nullptr, float *amax_dptr = nullptr, float *scale_dptr = nullptr,
float *scale_inv_dptr = nullptr, const std::vector<size_t> &scale_inv_shape = {1}, float *scale_inv_dptr = nullptr, const std::vector<size_t> &scale_inv_shape = {1},
const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING)
: TensorWrapper(dptr, NVTEShape{shape.data(), shape.size()}, dtype, amax_dptr, scale_dptr, : TensorWrapper(dptr, nvte_make_shape(shape.data(), shape.size()), dtype, amax_dptr,
scale_inv_dptr, NVTEShape{scale_inv_shape.data(), scale_inv_shape.size()}, scale_dptr, scale_inv_dptr,
nvte_make_shape(scale_inv_shape.data(), scale_inv_shape.size()),
scaling_mode) {} scaling_mode) {}
/*! \brief Constructs new empty TensorWrapper. /*! \brief Constructs new empty TensorWrapper.
...@@ -534,7 +546,9 @@ class TensorWrapper { ...@@ -534,7 +546,9 @@ class TensorWrapper {
* \return Shape of this TensorWrapper. * \return Shape of this TensorWrapper.
*/ */
const NVTEShape shape() const noexcept { const NVTEShape shape() const noexcept {
if (tensor_ == nullptr) return NVTEShape{nullptr, 0}; if (tensor_ == nullptr) {
return nvte_make_shape(nullptr, 0);
}
return nvte_tensor_shape(tensor_); return nvte_tensor_shape(tensor_);
} }
...@@ -543,7 +557,9 @@ class TensorWrapper { ...@@ -543,7 +557,9 @@ class TensorWrapper {
* \return Shape of this TensorWrapper. * \return Shape of this TensorWrapper.
*/ */
const NVTEShape columnwise_shape() const noexcept { const NVTEShape columnwise_shape() const noexcept {
if (tensor_ == nullptr) return NVTEShape{nullptr, 0}; if (tensor_ == nullptr) {
return nvte_make_shape(nullptr, 0);
}
return nvte_tensor_columnwise_shape(tensor_); return nvte_tensor_columnwise_shape(tensor_);
} }
...@@ -656,7 +672,9 @@ class TensorWrapper { ...@@ -656,7 +672,9 @@ class TensorWrapper {
* \return scale_inv_shape of this TensorWrapper. * \return scale_inv_shape of this TensorWrapper.
*/ */
const NVTEShape scale_inv_shape() const noexcept { const NVTEShape scale_inv_shape() const noexcept {
if (tensor_ == nullptr) return NVTEShape{nullptr, 0}; if (tensor_ == nullptr) {
return nvte_make_shape(nullptr, 0);
}
return nvte_tensor_scale_inv_shape(tensor_); return nvte_tensor_scale_inv_shape(tensor_);
} }
...@@ -672,12 +690,20 @@ class TensorWrapper { ...@@ -672,12 +690,20 @@ class TensorWrapper {
void zero_(cudaStream_t stream) { nvte_zero_tensor(tensor_, stream); } void zero_(cudaStream_t stream) { nvte_zero_tensor(tensor_, stream); }
static constexpr size_t defaultData = 1; static constexpr size_t defaultData = 1;
static constexpr NVTEShape defaultShape = {&defaultData, 1}; static constexpr NVTEShape defaultShape = {
&defaultData, 1, {defaultData, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}};
private: private:
NVTEShape convertShape(const NVTEShape &s) { return s; } NVTEShape convertShape(const NVTEShape &s) {
NVTEShape ret = s;
// Move the ownership rather than pointing to the parent shape.
ret.data = ret.owned_data;
return ret;
}
NVTEShape convertShape(const std::vector<size_t> &s) { return {s.data(), s.size()}; } NVTEShape convertShape(const std::vector<size_t> &s) {
return nvte_make_shape(s.data(), s.size());
}
/*! \brief Wrapped NVTETensor. */ /*! \brief Wrapped NVTETensor. */
NVTETensor tensor_ = nullptr; NVTETensor tensor_ = nullptr;
......
...@@ -211,6 +211,22 @@ NVTEDType nvte_tensor_type(const NVTETensor tensor) { ...@@ -211,6 +211,22 @@ NVTEDType nvte_tensor_type(const NVTETensor tensor) {
reinterpret_cast<const transformer_engine::Tensor *>(tensor)->dtype()); reinterpret_cast<const transformer_engine::Tensor *>(tensor)->dtype());
} }
NVTEShape nvte_make_shape(const size_t *data, size_t ndim) {
  // Build a shape that owns a copy of `data` in its fixed-size owned_data
  // buffer, so callers are not tied to the lifetime of the input array.
  //
  // NOTE(review): the returned shape's `data` pointer refers into its own
  // `owned_data` member. A plain struct copy leaves the copy's `data`
  // pointing at the source object; consumers that copy an NVTEShape must
  // re-point `data` at their own `owned_data` (as TensorWrapper's
  // convertShape does).
  //
  // Brace-init so owned_data slots beyond `ndim` (and the ndim == 0 case)
  // hold zeros rather than indeterminate bytes.
  NVTEShape ret{};
  if (ndim == 0) {
    // Empty shape: no data to copy, represent as {nullptr, 0}.
    ret.data = nullptr;
    ret.ndim = 0;
    return ret;
  }
  constexpr size_t kMaxDims = sizeof(ret.owned_data) / sizeof(ret.owned_data[0]);
  NVTE_CHECK(ndim <= kMaxDims, "Too many dims for NVTEShape (requested: ", ndim,
             ", max: ", kMaxDims, ")");
  std::copy(data, data + ndim, ret.owned_data);
  ret.data = ret.owned_data;
  ret.ndim = ndim;
  return ret;
}
NVTEShape nvte_tensor_shape(const NVTETensor tensor) { NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
if (tensor == nullptr) { if (tensor == nullptr) {
NVTE_ERROR("Invalid tensor"); NVTE_ERROR("Invalid tensor");
...@@ -218,12 +234,9 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor) { ...@@ -218,12 +234,9 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
// Determine tensor shape depending on tensor format // Determine tensor shape depending on tensor format
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor); const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
const std::vector<size_t> &rowwise_shape = t.rowwise_shape_ref(); std::vector<size_t> shape = t.shape();
NVTEShape ret; return nvte_make_shape(shape.data(), shape.size());
ret.data = rowwise_shape.data();
ret.ndim = rowwise_shape.size();
return ret;
} }
NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) { NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) {
...@@ -231,10 +244,7 @@ NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) { ...@@ -231,10 +244,7 @@ NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) {
NVTE_ERROR("Invalid tensor"); NVTE_ERROR("Invalid tensor");
} }
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor); const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
NVTEShape ret; return nvte_make_shape(t.columnwise_data.shape.data(), t.columnwise_data.shape.size());
ret.data = t.columnwise_data.shape.data();
ret.ndim = t.columnwise_data.shape.size();
return ret;
} }
size_t nvte_tensor_ndims(const NVTETensor tensor) { return nvte_tensor_shape(tensor).ndim; } size_t nvte_tensor_ndims(const NVTETensor tensor) { return nvte_tensor_shape(tensor).ndim; }
...@@ -302,12 +312,11 @@ void *nvte_tensor_columnwise_scale_inv(const NVTETensor tensor) { ...@@ -302,12 +312,11 @@ void *nvte_tensor_columnwise_scale_inv(const NVTETensor tensor) {
} }
NVTEShape nvte_tensor_scale_inv_shape(const NVTETensor tensor) { NVTEShape nvte_tensor_scale_inv_shape(const NVTETensor tensor) {
if (tensor == nullptr) return {nullptr, 0}; if (tensor == nullptr) {
return nvte_make_shape(nullptr, 0);
}
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor); const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
NVTEShape ret; return nvte_make_shape(t.scale_inv.shape.data(), t.scale_inv.shape.size());
ret.data = t.scale_inv.shape.data();
ret.ndim = t.scale_inv.shape.size();
return ret;
} }
void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name, void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name,
......
...@@ -3,9 +3,11 @@ ...@@ -3,9 +3,11 @@
* *
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include "extensions.h" #include "extensions.h"
#include "kv_cache.cuh" #include "kv_cache.cuh"
#include "thd_utils.cuh" #include "thd_utils.cuh"
#include "transformer_engine/transformer_engine.h"
constexpr int block_size = 512; constexpr int block_size = 512;
constexpr int ctas_per_sm = 4; constexpr int ctas_per_sm = 4;
...@@ -449,13 +451,13 @@ std::vector<py::object> fused_attn_bwd( ...@@ -449,13 +451,13 @@ std::vector<py::object> fused_attn_bwd(
nvte_tensor_pack_create(&nvte_aux_tensor_pack); nvte_tensor_pack_create(&nvte_aux_tensor_pack);
nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size(); nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size();
for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) {
std::vector<int64_t> tmp(Aux_CTX_Tensors[i].sizes().vec()); const std::vector<int64_t> &signed_shape = Aux_CTX_Tensors[i].sizes().vec();
auto temp_vec = std::vector<size_t>(tmp.begin(), tmp.end()); const std::vector<size_t> tmp(signed_shape.begin(), signed_shape.end());
const NVTEShape temp_shape = {temp_vec.data(), temp_vec.size()};
NVTEBasicTensor temp_data = { NVTEBasicTensor temp_data = {
Aux_CTX_Tensors[i].data_ptr(), Aux_CTX_Tensors[i].data_ptr(),
static_cast<NVTEDType>(GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type())), static_cast<NVTEDType>(GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type())),
temp_shape}; nvte_make_shape(tmp.data(), tmp.size())};
nvte_set_tensor_param(&nvte_aux_tensor_pack.tensors[i], kNVTERowwiseData, &temp_data); nvte_set_tensor_param(&nvte_aux_tensor_pack.tensors[i], kNVTERowwiseData, &temp_data);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment