#ifndef CK_STATIC_TENSOR_HPP #define CK_STATIC_TENSOR_HPP namespace ck { // StaticTensor for Scalar template ::type = false> struct StaticTensor { static constexpr auto desc_ = TensorDesc{}; static constexpr index_t ndim_ = TensorDesc::GetNumOfDimension(); static constexpr index_t element_space_size_ = desc_.GetElementSpaceSize(); __host__ __device__ constexpr StaticTensor() : invalid_element_scalar_value_{0} {} __host__ __device__ constexpr StaticTensor(T invalid_element_value) : invalid_element_scalar_value_{invalid_element_value} { } // read access template ::value && Idx::Size() == ndim_, bool>::type = false> __host__ __device__ constexpr const T& operator[](Idx) const { constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); constexpr index_t offset = coord.GetOffset(); constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord); if constexpr(is_valid) { return data_[Number{}]; } else { if constexpr(InvalidElementUseNumericalZeroValue) { return zero_scalar_value_; } else { return invalid_element_scalar_value_; } } } // write access template ::value && Idx::Size() == ndim_, bool>::type = false> __host__ __device__ constexpr T& operator()(Idx) { constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); constexpr index_t offset = coord.GetOffset(); constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord); if constexpr(is_valid) { return data_(Number{}); } else { return ignored_element_scalar_; } } StaticBuffer data_; static constexpr T zero_scalar_value_ = T{0}; const T invalid_element_scalar_value_; T ignored_element_scalar_; }; // StaticTensor for vector template ::type = false> struct StaticTensorTupleOfVectorBuffer { static constexpr auto desc_ = TensorDesc{}; static constexpr index_t ndim_ = TensorDesc::GetNumOfDimension(); static constexpr index_t element_space_size_ = desc_.GetElementSpaceSize(); static constexpr index_t num_of_vector_ = math::integer_divide_ceil(element_space_size_, ScalarPerVector); using V = vector_type; __host__ __device__ constexpr StaticTensorTupleOfVectorBuffer() : invalid_element_scalar_value_{0} { } __host__ __device__ constexpr StaticTensorTupleOfVectorBuffer(S invalid_element_value) : invalid_element_scalar_value_{invalid_element_value} { } // Get S // Idx is for S, not V template ::value && Idx::Size() == ndim_, bool>::type = false> __host__ __device__ constexpr const S& operator[](Idx) const { constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); constexpr index_t offset = coord.GetOffset(); constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord); if constexpr(is_valid) { return data_[Number{}]; } else { if constexpr(InvalidElementUseNumericalZeroValue) { return zero_scalar_value_; } else { return invalid_element_scalar_value_; } } } // Set S // Idx is for S, not V template ::value && Idx::Size() == ndim_, bool>::type = false> __host__ __device__ constexpr S& operator()(Idx) { constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); constexpr index_t offset = coord.GetOffset(); constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord); if constexpr(is_valid) { return data_(Number{}); } else { return ignored_element_scalar_; } } // Get X // Idx is for S, not X. Idx should be aligned with X template ::value && is_known_at_compile_time::value && Idx::Size() == ndim_, bool>::type = false> __host__ __device__ constexpr X GetAsType(Idx) const { constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); constexpr index_t offset = coord.GetOffset(); constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord); if constexpr(is_valid) { return data_.template GetAsType(Number{}); } else { if constexpr(InvalidElementUseNumericalZeroValue) { // TODO: is this right way to initialize a vector? return X{0}; } else { // TODO: is this right way to initialize a vector? return X{invalid_element_scalar_value_}; } } } // Set X // Idx is for S, not X. Idx should be aligned with X template ::value && is_known_at_compile_time::value && Idx::Size() == ndim_, bool>::type = false> __host__ __device__ constexpr void SetAsType(Idx, X x) { constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); constexpr index_t offset = coord.GetOffset(); constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord); if constexpr(is_valid) { data_.template SetAsType(Number{}, x); } } // Get read access to V. No is_valid check // Idx is for S, not V. Idx should be aligned with V template __host__ __device__ constexpr const V& GetVectorTypeReference(Idx) const { constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); constexpr index_t offset = coord.GetOffset(); return data_.GetVectorTypeReference(Number{}); } // Get read access to V. No is_valid check // Idx is for S, not V. Idx should be aligned with V template __host__ __device__ constexpr V& GetVectorTypeReference(Idx) { constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); constexpr index_t offset = coord.GetOffset(); return data_.GetVectorTypeReference(Number{}); } StaticBufferTupleOfVector data_; static constexpr S zero_scalar_value_ = S{0}; const S invalid_element_scalar_value_ = S{0}; S ignored_element_scalar_; }; template ::type = false> __host__ __device__ constexpr auto make_static_tensor(TensorDesc) { return StaticTensor{}; } template < AddressSpaceEnum AddressSpace, typename T, typename TensorDesc, typename X, typename enable_if::type = false, typename enable_if, remove_cvref_t>::value, bool>::type = false> __host__ __device__ constexpr auto make_static_tensor(TensorDesc, X invalid_element_value) { return StaticTensor{invalid_element_value}; } } // namespace ck #endif