/************************************************************************* * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ #include #include #include #include namespace transformer_engine { namespace jax { constexpr int kMaxNumDim = 8; struct Shape { int num_dim; size_t dims[kMaxNumDim]; void from_vector(const std::vector &shape); std::vector to_vector() const; }; std::vector MakeShapeVector(NVTEShape shape); inline size_t product(const std::vector &shape) { size_t ret = 1; for (const auto &elem : shape) { ret *= elem; } return ret; } enum class QuantizeAxis { ROWWISE, COLWISE, ROWWISE_COLWISE, }; } // namespace jax } // namespace transformer_engine