/************************************************************************* * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ #include #include #include #include namespace transformer_engine { namespace jax { constexpr int kMaxNumDim = 8; struct Shape { int num_dim; size_t dims[kMaxNumDim]; void from_vector(const std::vector &shape); std::vector to_vector() const; }; std::vector MakeShapeVector(NVTEShape shape); } // namespace jax } // namespace transformer_engine