/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "extensions.h"

#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

namespace transformer_engine {
namespace jax {

/// @brief Copies the dimensions held by an NVTEShape into an owning vector.
/// @param shape Shape whose first `shape.ndim` entries of `shape.data` are copied.
/// @return A vector of length `shape.ndim` containing those dimensions in order.
std::vector<size_t> MakeShapeVector(NVTEShape shape) {
  return std::vector<size_t>(shape.data, shape.data + shape.ndim);
}

/// @brief Stores the given dimensions into this fixed-capacity Shape.
/// @param shape Dimensions to store; its length becomes `num_dim` and its
///              entries are copied into the leading slots of `dims`.
/// @throws std::invalid_argument if `shape` has more than kMaxNumDim entries.
void Shape::from_vector(const std::vector<size_t> &shape) {
  // Hard check rather than `assert`: assertions are compiled out under NDEBUG,
  // and an unchecked copy would then write past the end of `dims`. Validate
  // before modifying any member so a rejected call leaves *this untouched.
  if (shape.size() > static_cast<size_t>(kMaxNumDim)) {
    throw std::invalid_argument("Shape::from_vector: rank " + std::to_string(shape.size()) +
                                " exceeds kMaxNumDim (" + std::to_string(kMaxNumDim) + ")");
  }
  num_dim = static_cast<decltype(num_dim)>(shape.size());
  // std::copy is well-defined for an empty range; memcpy with the possibly-null
  // data() of an empty vector is undefined behavior even for a zero count.
  std::copy(shape.begin(), shape.end(), dims);
}

/// @brief Returns the stored dimensions as an owning vector.
/// @return A vector holding the first `num_dim` entries of `dims`, in order.
/// @throws std::logic_error if `num_dim` is outside [0, kMaxNumDim], i.e. the
///         object holds a rank it could never have been given legitimately.
std::vector<size_t> Shape::to_vector() const {
  // Casting to size_t makes a negative value (if num_dim is signed) compare as
  // huge, so this single comparison rejects both negative and too-large ranks.
  if (static_cast<size_t>(num_dim) > static_cast<size_t>(kMaxNumDim)) {
    throw std::logic_error("Shape::to_vector: num_dim " + std::to_string(num_dim) +
                           " is outside [0, kMaxNumDim (" + std::to_string(kMaxNumDim) + ")]");
  }
  return std::vector<size_t>(dims, dims + num_dim);
}

}  // namespace jax
}  // namespace transformer_engine