Unverified Commit afb70224 authored by kwyss-nvidia's avatar kwyss-nvidia Committed by GitHub
Browse files

Kwyss/new shape owns data (#1708)

* Reapply "Allow NVTEShape to own data." (#1703)

This reverts commit 91405eb4.

Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update code so that data is replaced by an array.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Specify unambiguous Tensor constructor in tests.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Fix assumption in test of 2D shape.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Remove row and col.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

---------
Signed-off-by: Keith Wyss <kwyss@nvidia.com>
parent 21ec6e04
...@@ -109,7 +109,7 @@ class Tensor { ...@@ -109,7 +109,7 @@ class Tensor {
const bool rowwise = true, const bool rowwise = true,
const bool columnwise = false, const bool columnwise = false,
const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING) : const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING) :
Tensor(name, NVTEShape{shape.data(), shape.size()}, type, rowwise, columnwise, mode) {} Tensor(name, nvte_make_shape(shape.data(), shape.size()), type, rowwise, columnwise, mode) {}
Tensor() {} Tensor() {}
......
...@@ -564,14 +564,17 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape, ...@@ -564,14 +564,17 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
void *buffer_ptr; void *buffer_ptr;
_ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true); _ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true);
if (_rank == 0) printf("!!! [UBP2P] Register UBuf %d\n", _ub_reg); if (_rank == 0) printf("!!! [UBP2P] Register UBuf %d\n", _ub_reg);
_ubuf = TensorWrapper(buffer_ptr, {buffer_shape[0] / tp_size * _num_ubuf_chunks, buffer_shape[1]}, _ubuf = TensorWrapper(
buffer_ptr,
std::vector<size_t>{buffer_shape[0] / tp_size * _num_ubuf_chunks, buffer_shape[1]},
buffer_dtype); buffer_dtype);
// Create tensor chunks for easy management // Create tensor chunks for easy management
char *ubuf_byte_ptr = reinterpret_cast<char *>(buffer_ptr); char *ubuf_byte_ptr = reinterpret_cast<char *>(buffer_ptr);
for (int i = 0; i < _num_ubuf_chunks; i++) { for (int i = 0; i < _num_ubuf_chunks; i++) {
_ubufs.push_back(TensorWrapper(reinterpret_cast<void *>(ubuf_byte_ptr), _ubufs.push_back(TensorWrapper(reinterpret_cast<void *>(ubuf_byte_ptr),
{buffer_shape[0] / tp_size, buffer_shape[1]}, buffer_dtype)); std::vector<size_t>{buffer_shape[0] / tp_size, buffer_shape[1]},
buffer_dtype));
ubuf_byte_ptr += buffer_chunk_bytes; ubuf_byte_ptr += buffer_chunk_bytes;
} }
......
...@@ -78,8 +78,8 @@ struct SimpleTensor { ...@@ -78,8 +78,8 @@ struct SimpleTensor {
SimpleTensor() : SimpleTensor(nullptr, {}, DType::kFloat32) {} SimpleTensor() : SimpleTensor(nullptr, {}, DType::kFloat32) {}
operator NVTEBasicTensor() const { operator NVTEBasicTensor() const {
const NVTEShape shape = {this->shape.data(), this->shape.size()}; return {dptr, static_cast<NVTEDType>(dtype),
return {dptr, static_cast<NVTEDType>(dtype), shape}; nvte_make_shape(this->shape.data(), this->shape.size())};
} }
int numel() const { int numel() const {
...@@ -99,11 +99,6 @@ struct Tensor { ...@@ -99,11 +99,6 @@ struct Tensor {
SimpleTensor scale_inv; SimpleTensor scale_inv;
SimpleTensor columnwise_scale_inv; SimpleTensor columnwise_scale_inv;
private:
// Used as an allocation for nvte_tensor_shape
// if the shape has to be inferred from columnwise data.
mutable std::vector<size_t> rowwise_shape_cache;
public: public:
NVTEScalingMode scaling_mode; NVTEScalingMode scaling_mode;
...@@ -194,22 +189,6 @@ struct Tensor { ...@@ -194,22 +189,6 @@ struct Tensor {
} }
} }
/*! \brief Return a reference to a cached copy of the current shape.
 *
 * The result of shape() is stored in rowwise_shape_cache and a reference
 * to that member is returned. The cache is overwritten only when the
 * freshly queried shape compares unequal to the stored one.
 */
const std::vector<size_t> &rowwise_shape_ref() const {
std::vector<size_t> current_shape = shape();
// Leave the cached vector untouched when it already matches, so that
// data pointers handed out by earlier calls are not replaced needlessly.
// A differing shape replaces the cache.
const bool cache_is_current = (rowwise_shape_cache == current_shape);
if (!cache_is_current) {
rowwise_shape_cache = std::move(current_shape);
}
return rowwise_shape_cache;
}
/*! Matrix height after tensor is flattened to 2D /*! Matrix height after tensor is flattened to 2D
* *
* If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted * If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted
......
...@@ -38,8 +38,8 @@ enum NVTEDType { ...@@ -38,8 +38,8 @@ enum NVTEDType {
* \brief Shape of the tensor. * \brief Shape of the tensor.
*/ */
struct NVTEShape { struct NVTEShape {
/*! \brief Shape data, of size ndim. */ /*! \brief Shape data, with ndim valid elements. */
const size_t *data; size_t data[15];
/*! \brief Number of dimensions. */ /*! \brief Number of dimensions. */
size_t ndim; size_t ndim;
}; };
...@@ -134,6 +134,15 @@ void *nvte_tensor_data(const NVTETensor tensor); ...@@ -134,6 +134,15 @@ void *nvte_tensor_data(const NVTETensor tensor);
*/ */
void *nvte_tensor_columnwise_data(const NVTETensor tensor); void *nvte_tensor_columnwise_data(const NVTETensor tensor);
/*! \brief Construct a shape from an array of dimension sizes.
*
* \param[in] data Pointer to start of shape array. Not read when ndim is 0.
* \param[in] ndim Number of dimensions (must be <= 15, the capacity of
*                 NVTEShape::data).
*
* \return A shape. The shape will own its own copy of the data.
*/
NVTEShape nvte_make_shape(const size_t *data, size_t ndim);
/*! \brief Get a tensor's data shape. /*! \brief Get a tensor's data shape.
* *
* \param[in] tensor Tensor. * \param[in] tensor Tensor.
...@@ -434,8 +443,9 @@ class TensorWrapper { ...@@ -434,8 +443,9 @@ class TensorWrapper {
float *amax_dptr = nullptr, float *scale_dptr = nullptr, float *amax_dptr = nullptr, float *scale_dptr = nullptr,
float *scale_inv_dptr = nullptr, const std::vector<size_t> &scale_inv_shape = {1}, float *scale_inv_dptr = nullptr, const std::vector<size_t> &scale_inv_shape = {1},
const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING)
: TensorWrapper(dptr, NVTEShape{shape.data(), shape.size()}, dtype, amax_dptr, scale_dptr, : TensorWrapper(dptr, nvte_make_shape(shape.data(), shape.size()), dtype, amax_dptr,
scale_inv_dptr, NVTEShape{scale_inv_shape.data(), scale_inv_shape.size()}, scale_dptr, scale_inv_dptr,
nvte_make_shape(scale_inv_shape.data(), scale_inv_shape.size()),
scaling_mode) {} scaling_mode) {}
/*! \brief Constructs new empty TensorWrapper. /*! \brief Constructs new empty TensorWrapper.
...@@ -551,7 +561,9 @@ class TensorWrapper { ...@@ -551,7 +561,9 @@ class TensorWrapper {
* \return Shape of this TensorWrapper. * \return Shape of this TensorWrapper.
*/ */
const NVTEShape shape() const noexcept { const NVTEShape shape() const noexcept {
if (tensor_ == nullptr) return NVTEShape{nullptr, 0}; if (tensor_ == nullptr) {
return nvte_make_shape(nullptr, 0);
}
return nvte_tensor_shape(tensor_); return nvte_tensor_shape(tensor_);
} }
...@@ -560,7 +572,9 @@ class TensorWrapper { ...@@ -560,7 +572,9 @@ class TensorWrapper {
* \return Shape of this TensorWrapper. * \return Shape of this TensorWrapper.
*/ */
const NVTEShape columnwise_shape() const noexcept { const NVTEShape columnwise_shape() const noexcept {
if (tensor_ == nullptr) return NVTEShape{nullptr, 0}; if (tensor_ == nullptr) {
return nvte_make_shape(nullptr, 0);
}
return nvte_tensor_columnwise_shape(tensor_); return nvte_tensor_columnwise_shape(tensor_);
} }
...@@ -673,7 +687,9 @@ class TensorWrapper { ...@@ -673,7 +687,9 @@ class TensorWrapper {
* \return scale_inv_shape of this TensorWrapper. * \return scale_inv_shape of this TensorWrapper.
*/ */
const NVTEShape scale_inv_shape() const noexcept { const NVTEShape scale_inv_shape() const noexcept {
if (tensor_ == nullptr) return NVTEShape{nullptr, 0}; if (tensor_ == nullptr) {
return nvte_make_shape(nullptr, 0);
}
return nvte_tensor_scale_inv_shape(tensor_); return nvte_tensor_scale_inv_shape(tensor_);
} }
...@@ -689,12 +705,15 @@ class TensorWrapper { ...@@ -689,12 +705,15 @@ class TensorWrapper {
void zero_(cudaStream_t stream) { nvte_zero_tensor(tensor_, stream); } void zero_(cudaStream_t stream) { nvte_zero_tensor(tensor_, stream); }
static constexpr size_t defaultData = 1; static constexpr size_t defaultData = 1;
static constexpr NVTEShape defaultShape = {&defaultData, 1}; static constexpr NVTEShape defaultShape = {
{defaultData, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1};
private: private:
NVTEShape convertShape(const NVTEShape &s) { return s; } NVTEShape convertShape(const NVTEShape &s) { return s; }
NVTEShape convertShape(const std::vector<size_t> &s) { return {s.data(), s.size()}; } NVTEShape convertShape(const std::vector<size_t> &s) {
return nvte_make_shape(s.data(), s.size());
}
/*! \brief Wrapped NVTETensor. */ /*! \brief Wrapped NVTETensor. */
NVTETensor tensor_ = nullptr; NVTETensor tensor_ = nullptr;
......
...@@ -212,6 +212,20 @@ NVTEDType nvte_tensor_type(const NVTETensor tensor) { ...@@ -212,6 +212,20 @@ NVTEDType nvte_tensor_type(const NVTETensor tensor) {
reinterpret_cast<const transformer_engine::Tensor *>(tensor)->dtype()); reinterpret_cast<const transformer_engine::Tensor *>(tensor)->dtype());
} }
/*! \brief Build an NVTEShape holding its own copy of the dimension sizes.
 *
 * \param[in] data Pointer to the first of ndim dimension sizes. It is not
 *                 read when ndim is 0, so nullptr is accepted in that case.
 * \param[in] ndim Number of dimensions. Must not exceed the capacity of
 *                 NVTEShape::data.
 *
 * \return A shape whose first ndim entries are copied from data and whose
 *         remaining entries are zero.
 */
NVTEShape nvte_make_shape(const size_t *data, size_t ndim) {
  // Value-initialize so that every entry past ndim (all of them when
  // ndim == 0) holds zero rather than an indeterminate value; the struct is
  // returned and copied by value, which reads the whole array.
  NVTEShape ret{};
  if (ndim == 0) {
    return ret;
  }
  constexpr size_t max_ndim = sizeof(ret.data) / sizeof(ret.data[0]);
  NVTE_CHECK(ndim <= max_ndim, "Too many dims for NVTEShape (requested: ", ndim,
             ", max: ", max_ndim, ")");
  // Report a null source explicitly instead of letting std::copy read it.
  NVTE_CHECK(data != nullptr, "Shape data is null but ndim is ", ndim);
  std::copy(data, data + ndim, ret.data);
  ret.ndim = ndim;
  return ret;
}
NVTEShape nvte_tensor_shape(const NVTETensor tensor) { NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
if (tensor == nullptr) { if (tensor == nullptr) {
NVTE_ERROR("Invalid tensor"); NVTE_ERROR("Invalid tensor");
...@@ -219,12 +233,9 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor) { ...@@ -219,12 +233,9 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
// Determine tensor shape depending on tensor format // Determine tensor shape depending on tensor format
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor); const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
const std::vector<size_t> &rowwise_shape = t.rowwise_shape_ref(); std::vector<size_t> shape = t.shape();
NVTEShape ret; return nvte_make_shape(shape.data(), shape.size());
ret.data = rowwise_shape.data();
ret.ndim = rowwise_shape.size();
return ret;
} }
NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) { NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) {
...@@ -232,10 +243,7 @@ NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) { ...@@ -232,10 +243,7 @@ NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) {
NVTE_ERROR("Invalid tensor"); NVTE_ERROR("Invalid tensor");
} }
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor); const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
NVTEShape ret; return nvte_make_shape(t.columnwise_data.shape.data(), t.columnwise_data.shape.size());
ret.data = t.columnwise_data.shape.data();
ret.ndim = t.columnwise_data.shape.size();
return ret;
} }
size_t nvte_tensor_ndims(const NVTETensor tensor) { return nvte_tensor_shape(tensor).ndim; } size_t nvte_tensor_ndims(const NVTETensor tensor) { return nvte_tensor_shape(tensor).ndim; }
...@@ -303,12 +311,11 @@ void *nvte_tensor_columnwise_scale_inv(const NVTETensor tensor) { ...@@ -303,12 +311,11 @@ void *nvte_tensor_columnwise_scale_inv(const NVTETensor tensor) {
} }
NVTEShape nvte_tensor_scale_inv_shape(const NVTETensor tensor) { NVTEShape nvte_tensor_scale_inv_shape(const NVTETensor tensor) {
if (tensor == nullptr) return {nullptr, 0}; if (tensor == nullptr) {
return nvte_make_shape(nullptr, 0);
}
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor); const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
NVTEShape ret; return nvte_make_shape(t.scale_inv.shape.data(), t.scale_inv.shape.size());
ret.data = t.scale_inv.shape.data();
ret.ndim = t.scale_inv.shape.size();
return ret;
} }
void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name, void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name,
...@@ -342,7 +349,7 @@ void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name, ...@@ -342,7 +349,7 @@ void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name,
NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam param_name) { NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam param_name) {
if (tensor == nullptr) { if (tensor == nullptr) {
return {nullptr, kNVTEFloat32, {nullptr, 0}}; return {nullptr, kNVTEFloat32, nvte_make_shape(nullptr, 0)};
} }
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor); const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
switch (param_name) { switch (param_name) {
......
...@@ -3,9 +3,11 @@ ...@@ -3,9 +3,11 @@
* *
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include "extensions.h" #include "extensions.h"
#include "kv_cache.cuh" #include "kv_cache.cuh"
#include "thd_utils.cuh" #include "thd_utils.cuh"
#include "transformer_engine/transformer_engine.h"
constexpr int block_size = 512; constexpr int block_size = 512;
constexpr int ctas_per_sm = 4; constexpr int ctas_per_sm = 4;
...@@ -449,13 +451,13 @@ std::vector<py::object> fused_attn_bwd( ...@@ -449,13 +451,13 @@ std::vector<py::object> fused_attn_bwd(
nvte_tensor_pack_create(&nvte_aux_tensor_pack); nvte_tensor_pack_create(&nvte_aux_tensor_pack);
nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size(); nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size();
for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) {
std::vector<int64_t> tmp(Aux_CTX_Tensors[i].sizes().vec()); const std::vector<int64_t> &signed_shape = Aux_CTX_Tensors[i].sizes().vec();
auto temp_vec = std::vector<size_t>(tmp.begin(), tmp.end()); const std::vector<size_t> tmp(signed_shape.begin(), signed_shape.end());
const NVTEShape temp_shape = {temp_vec.data(), temp_vec.size()};
NVTEBasicTensor temp_data = { NVTEBasicTensor temp_data = {
Aux_CTX_Tensors[i].data_ptr(), Aux_CTX_Tensors[i].data_ptr(),
static_cast<NVTEDType>(GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type())), static_cast<NVTEDType>(GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type())),
temp_shape}; nvte_make_shape(tmp.data(), tmp.size())};
nvte_set_tensor_param(&nvte_aux_tensor_pack.tensors[i], kNVTERowwiseData, &temp_data); nvte_set_tensor_param(&nvte_aux_tensor_pack.tensors[i], kNVTERowwiseData, &temp_data);
} }
......
...@@ -167,8 +167,8 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans ...@@ -167,8 +167,8 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
makeTransformerEngineTensor(get_data_ptr(pre_gelu_out), gelu_shape, gelu_type); makeTransformerEngineTensor(get_data_ptr(pre_gelu_out), gelu_shape, gelu_type);
// Workspace // Workspace
auto te_workspace = auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(),
makeTransformerEngineTensor(workspace.data_ptr(), {workspaceSize}, DType::kByte); std::vector<size_t>{workspaceSize}, DType::kByte);
// Set an external SM Margin to all the GEMMs. // Set an external SM Margin to all the GEMMs.
// This comes in handy when DP is overlapped with GEMMs // This comes in handy when DP is overlapped with GEMMs
...@@ -286,12 +286,13 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine ...@@ -286,12 +286,13 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine
nvte_scaling_modeB); nvte_scaling_modeB);
// TODO: D_scale_inv cannot be nullptr when D_type is FP8. // TODO: D_scale_inv cannot be nullptr when D_type is FP8.
auto te_D = makeTransformerEngineTensor( auto te_D = makeTransformerEngineTensor(
D.data_ptr(), {static_cast<size_t>(D.size(0)), static_cast<size_t>(D.size(1))}, D_type, D.data_ptr(),
std::vector<size_t>{static_cast<size_t>(D.size(0)), static_cast<size_t>(D.size(1))}, D_type,
D_amax.data_ptr(), D_scale.data_ptr(), nullptr); D_amax.data_ptr(), D_scale.data_ptr(), nullptr);
auto te_bias = auto te_bias = makeTransformerEngineTensor(
makeTransformerEngineTensor(bias.data_ptr(), {static_cast<size_t>(bias.size(0))}, bias_type); bias.data_ptr(), std::vector<size_t>{static_cast<size_t>(bias.size(0))}, bias_type);
auto te_counter = makeTransformerEngineTensor( auto te_counter = makeTransformerEngineTensor(
counter.data_ptr(), {static_cast<size_t>(counter.size(0))}, DType::kInt32); counter.data_ptr(), std::vector<size_t>{static_cast<size_t>(counter.size(0))}, DType::kInt32);
const auto gelu_shape = pre_gelu_out.data_ptr() == nullptr const auto gelu_shape = pre_gelu_out.data_ptr() == nullptr
? std::vector<size_t>{static_cast<size_t>(pre_gelu_out.size(0))} ? std::vector<size_t>{static_cast<size_t>(pre_gelu_out.size(0))}
...@@ -299,8 +300,8 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine ...@@ -299,8 +300,8 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine
static_cast<size_t>(pre_gelu_out.size(1))}; static_cast<size_t>(pre_gelu_out.size(1))};
auto te_pre_gelu_out = makeTransformerEngineTensor( auto te_pre_gelu_out = makeTransformerEngineTensor(
pre_gelu_out.data_ptr(), gelu_shape, GetTransformerEngineDType(pre_gelu_out.scalar_type())); pre_gelu_out.data_ptr(), gelu_shape, GetTransformerEngineDType(pre_gelu_out.scalar_type()));
auto te_workspace = auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(),
makeTransformerEngineTensor(workspace.data_ptr(), {workspaceSize}, DType::kByte); std::vector<size_t>{workspaceSize}, DType::kByte);
nvte_cublas_atomic_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(), nvte_cublas_atomic_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(),
te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(), te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(),
...@@ -419,7 +420,8 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm( ...@@ -419,7 +420,8 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
wrappers.emplace_back(std::move(te_pre_gelu_out)); wrappers.emplace_back(std::move(te_pre_gelu_out));
} }
for (size_t i = 0; i < workspace.size(); i++) { for (size_t i = 0; i < workspace.size(); i++) {
auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(), {workspaceSize}, DType::kByte); auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(),
std::vector<size_t>{workspaceSize}, DType::kByte);
te_workspace_vector.emplace_back(wsp.data()); te_workspace_vector.emplace_back(wsp.data());
wrappers.emplace_back(std::move(wsp)); wrappers.emplace_back(std::move(wsp));
} }
......
...@@ -61,12 +61,16 @@ std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd( ...@@ -61,12 +61,16 @@ std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd(
auto stream = at::cuda::getCurrentCUDAStream().stream(); auto stream = at::cuda::getCurrentCUDAStream().stream();
auto input_cu = makeTransformerEngineTensor( auto input_cu = makeTransformerEngineTensor(
input.data_ptr(), {static_cast<size_t>(input.size(0)), static_cast<size_t>(num_cols)}, dtype); input.data_ptr(),
auto permuted_output_cu = makeTransformerEngineTensor( std::vector<size_t>{static_cast<size_t>(input.size(0)), static_cast<size_t>(num_cols)},
permuted_output.data_ptr(), dtype);
{static_cast<size_t>(permuted_output.size(0)), static_cast<size_t>(num_cols)}, dtype); auto permuted_output_cu =
auto sorted_row_id_cu = makeTransformerEngineTensor(permuted_output.data_ptr(),
makeTransformerEngineTensor(sorted_row_id_ptr, {static_cast<size_t>(num_tokens * topK)}, std::vector<size_t>{static_cast<size_t>(permuted_output.size(0)),
static_cast<size_t>(num_cols)},
dtype);
auto sorted_row_id_cu = makeTransformerEngineTensor(
sorted_row_id_ptr, std::vector<size_t>{static_cast<size_t>(num_tokens * topK)},
transformer_engine::DType::kInt32); transformer_engine::DType::kInt32);
auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); auto row_id_map_cu = makeTransformerEngineTensor(row_id_map);
...@@ -99,10 +103,14 @@ at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType d ...@@ -99,10 +103,14 @@ at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType d
auto stream = at::cuda::getCurrentCUDAStream().stream(); auto stream = at::cuda::getCurrentCUDAStream().stream();
auto input_cu = makeTransformerEngineTensor( auto input_cu = makeTransformerEngineTensor(
input.data_ptr(), {static_cast<size_t>(input.size(0)), static_cast<size_t>(num_cols)}, dtype); input.data_ptr(),
std::vector<size_t>{static_cast<size_t>(input.size(0)), static_cast<size_t>(num_cols)},
dtype);
auto unpermuted_output_cu = makeTransformerEngineTensor( auto unpermuted_output_cu = makeTransformerEngineTensor(
unpermuted_output.data_ptr(), unpermuted_output.data_ptr(),
{static_cast<size_t>(unpermuted_output.size(0)), static_cast<size_t>(num_cols)}, dtype); std::vector<size_t>{static_cast<size_t>(unpermuted_output.size(0)),
static_cast<size_t>(num_cols)},
dtype);
auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); auto row_id_map_cu = makeTransformerEngineTensor(row_id_map);
auto prob_cu = makeTransformerEngineTensor(prob); auto prob_cu = makeTransformerEngineTensor(prob);
...@@ -130,13 +138,16 @@ std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::T ...@@ -130,13 +138,16 @@ std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::T
auto stream = at::cuda::getCurrentCUDAStream().stream(); auto stream = at::cuda::getCurrentCUDAStream().stream();
auto input_bwd_cu = makeTransformerEngineTensor( auto input_bwd_cu = makeTransformerEngineTensor(
input_bwd.data_ptr(), {static_cast<size_t>(input_bwd.size(0)), static_cast<size_t>(num_cols)}, input_bwd.data_ptr(),
std::vector<size_t>{static_cast<size_t>(input_bwd.size(0)), static_cast<size_t>(num_cols)},
dtype); dtype);
auto act_grad_cu = makeTransformerEngineTensor( auto act_grad_cu = makeTransformerEngineTensor(
act_grad.data_ptr(), {static_cast<size_t>(act_grad.size(0)), static_cast<size_t>(num_cols)}, act_grad.data_ptr(),
std::vector<size_t>{static_cast<size_t>(act_grad.size(0)), static_cast<size_t>(num_cols)},
dtype); dtype);
auto input_fwd_cu = makeTransformerEngineTensor( auto input_fwd_cu = makeTransformerEngineTensor(
input_fwd.data_ptr(), {static_cast<size_t>(input_fwd.size(0)), static_cast<size_t>(num_cols)}, input_fwd.data_ptr(),
std::vector<size_t>{static_cast<size_t>(input_fwd.size(0)), static_cast<size_t>(num_cols)},
dtype); dtype);
auto row_id_map_cu = makeTransformerEngineTensor(row_id_map); auto row_id_map_cu = makeTransformerEngineTensor(row_id_map);
auto prob_cu = makeTransformerEngineTensor(prob); auto prob_cu = makeTransformerEngineTensor(prob);
......
...@@ -100,8 +100,8 @@ at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype, ...@@ -100,8 +100,8 @@ at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype,
} }
if (M == 0 || N == 0) return out; if (M == 0 || N == 0) return out;
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, otype); auto input_cu = makeTransformerEngineTensor(input.data_ptr(), std::vector<size_t>{M, N}, otype);
auto output_cu = makeTransformerEngineTensor(out.data_ptr(), {N, M}, otype); auto output_cu = makeTransformerEngineTensor(out.data_ptr(), std::vector<size_t>{N, M}, otype);
nvte_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); nvte_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment