/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "extensions.h"

#include <algorithm>

namespace transformer_engine {
namespace jax {

// Copies the `shape.ndim` extents starting at `shape.data` into an owning std::vector.
// The result does not alias `shape`, so it remains valid after `shape` goes away.
std::vector<size_t> MakeShapeVector(NVTEShape shape) {
  return std::vector<size_t>(shape.data, shape.data + shape.ndim);
}

// Stores the extents of `shape` in `dims` and records how many there are in `num_dim`.
//
// `dims` is assumed to hold at most kMaxNumDim entries (declared in extensions.h; TODO
// confirm). The capacity check is an `assert`, so it is compiled out under NDEBUG. In
// release builds, callers must not pass more than kMaxNumDim dimensions.
void Shape::from_vector(const std::vector<size_t> &shape) {
  // Check the source length itself rather than `num_dim` after assignment, so a narrower
  // `num_dim` type cannot hide an oversized input.
  assert(shape.size() <= static_cast<size_t>(kMaxNumDim));
  num_dim = shape.size();
  // Iterator-based copy instead of memcpy: for an empty vector, `shape.data()` may be
  // null, and passing a null pointer to memcpy is undefined even with a zero byte count.
  std::copy(shape.begin(), shape.end(), dims);
}

// Returns the first `num_dim` entries of `dims` as a new std::vector.
// The bound check is an `assert` and is therefore compiled out under NDEBUG.
std::vector<size_t> Shape::to_vector() const {
  assert(num_dim <= kMaxNumDim);
  // Range constructor instead of allocate-then-memcpy: when num_dim == 0 the destination
  // vector's data() may be null, and memcpy with a null pointer is undefined even for
  // zero bytes. This also mirrors MakeShapeVector.
  return std::vector<size_t>(dims, dims + num_dim);
}

}  // namespace jax
}  // namespace transformer_engine