/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "src/fastertransformer/utils/cuda_bf16_wrapper.h"
#include "src/fastertransformer/utils/cuda_fp8_utils.h"
#include "src/fastertransformer/utils/cuda_utils.h"
#include "src/fastertransformer/utils/string_utils.h"

#include "stdlib.h"
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <dirent.h>
#include <numeric>
#include <stdlib.h>
#include <string>
#include <sys/stat.h>
#include <unordered_map>
#include <vector>

namespace fastertransformer {

typedef enum datatype_enum {
    TYPE_INVALID,
    TYPE_BOOL,
    TYPE_UINT8,
    TYPE_UINT16,
    TYPE_UINT32,
    TYPE_UINT64,
    TYPE_INT8,
    TYPE_INT16,
    TYPE_INT32,
    TYPE_INT64,
    TYPE_FP16,
    TYPE_FP32,
    TYPE_FP64,
    TYPE_BYTES,
    TYPE_BF16,
    TYPE_FP8_E4M3,
    TYPE_STR,
    TYPE_VOID,
} DataType;

template<typename T>
DataType getTensorType()
{
    if (std::is_same<T, float>::value || std::is_same<T, const float>::value) {
        return TYPE_FP32;
    }
    else if (std::is_same<T, half>::value || std::is_same<T, const half>::value) {
        return TYPE_FP16;
    }
#ifdef ENABLE_BF16
    else if (std::is_same<T, __nv_bfloat16>::value || std::is_same<T, const __nv_bfloat16>::value) {
        return TYPE_BF16;
    }
#endif
#ifdef ENABLE_FP8
    else if (std::is_same<T, __nv_fp8_e4m3>::value || std::is_same<T, const __nv_fp8_e4m3>::value) {
        return TYPE_FP8_E4M3;
    }
#endif
    else if (std::is_same<T, int>::value || std::is_same<T, const int>::value) {
        return TYPE_INT32;
    }
    else if (std::is_same<T, int8_t>::value || std::is_same<T, const int8_t>::value) {
        return TYPE_INT8;
    }
    else if (std::is_same<T, uint>::value || std::is_same<T, const uint>::value) {
        return TYPE_UINT32;
    }
    else if (std::is_same<T, unsigned long long int>::value || std::is_same<T, const unsigned long long int>::value) {
        return TYPE_UINT64;
    }
    else if (std::is_same<T, bool>::value || std::is_same<T, const bool>::value) {
        return TYPE_BOOL;
    }
    else if (std::is_same<T, char>::value || std::is_same<T, const char>::value) {
        return TYPE_BYTES;
    }
    else {
        return TYPE_INVALID;
    }
}

typedef enum memorytype_enum {
    MEMORY_CPU,
    MEMORY_CPU_PINNED,
    MEMORY_GPU
} MemoryType;

struct Tensor {
    const MemoryType          where;
    const DataType            type;
    const std::vector<size_t> shape;
    const void*               data;  // TODO(bhseuh) modify from const void* to void* const
    const std::vector<size_t> offsets = std::vector<size_t>{};

    Tensor();
    Tensor(const MemoryType _where, const DataType _type, const std::vector<size_t> _shape, const void* _data);
    Tensor(const MemoryType          _where,
           const DataType            _type,
           const std::vector<size_t> _shape,
           const void*               _data,
           const std::vector<size_t> _offset);

    size_t size() const;
    size_t sizeBytes() const;

    std::string whereToString() const;
    std::string toString() const;
    std::string getNumpyTypeDesc(DataType type) const;

    void          saveNpy(const std::string& filename) const;
    static Tensor loadNpy(const std::string& npy_file, const MemoryType where);

    static DataType typeFromNumpyDesc(std::string type);
    static size_t   getTypeSize(DataType type);

    template<typename T>
    inline T getVal(size_t index) const
    {
        FT_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
        FT_CHECK(where == MEMORY_CPU);
        FT_CHECK(data != nullptr);
        FT_CHECK_WITH_INFO(index < size(), "index is larger than buffer size");

        if (getTensorType<T>() != type) {
            FT_LOG_DEBUG("getVal with type %s, but data type is: %s",
                         getNumpyTypeDesc(getTensorType<T>()).c_str(),
                         getNumpyTypeDesc(type).c_str());
        }
        return ((T*)data)[index];
    }

    template<typename T>
    inline T getVal() const
    {
        FT_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
        if (getTensorType<T>() != type) {
            FT_LOG_DEBUG("getVal with type %s, but data type is: %s",
                         getNumpyTypeDesc(getTensorType<T>()).c_str(),
                         getNumpyTypeDesc(type).c_str());
        }
        return getVal<T>(0);
    }
    template<typename T>
    inline T* getPtr() const
    {
        FT_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
        if (getTensorType<T>() != type) {
            FT_LOG_DEBUG("getPtr with type %s, but data type is: %s",
                         getNumpyTypeDesc(getTensorType<T>()).c_str(),
                         getNumpyTypeDesc(type).c_str());
        }
        return (T*)data;
    }

    inline void* getPtrWithOffset(size_t offset) const
    {
        FT_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
        if (data == nullptr) {
            return (void*)data;
        }
        else {
            FT_CHECK_WITH_INFO(offset < size(), "offset is larger than buffer size");
            return (void*)((char*)data + offset * Tensor::getTypeSize(type));
        }
    }

    template<typename T>
    inline T* getPtrWithOffset(size_t offset) const
    {
        FT_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
        if (getTensorType<T>() != type) {
            FT_LOG_DEBUG("getVal with type %s, but data type is: %s",
                         getNumpyTypeDesc(getTensorType<T>()).c_str(),
                         getNumpyTypeDesc(type).c_str());
        }
        if (data == nullptr) {
            return (T*)data;
        }
        else {
            FT_CHECK_WITH_INFO(offset < size(),
                               fmtstr("offset (%lu) is larger than buffer size (%lu)", offset, size()));
            return ((T*)data) + offset;
        }
    }

    template<typename T>
    T max() const
    {
        if (getTensorType<T>() != type) {
            FT_LOG_DEBUG("getVal with type %s, but data type is: %s",
                         getNumpyTypeDesc(getTensorType<T>()).c_str(),
                         getNumpyTypeDesc(type).c_str());
        }
        FT_CHECK_WITH_INFO(shape.size() > 0 && data != nullptr, "Should be a non-empty tensor.");
        FT_CHECK_WITH_INFO(where == MEMORY_CPU || where == MEMORY_CPU_PINNED,
                           "max() supports MEMORY_CPU or MEMORY_CPU_PINNED tensor.");
        size_t max_idx = 0;
        T      max_val = getVal<T>(max_idx);
        for (size_t i = 1; i < size(); ++i) {
            T val = getVal<T>(i);
            if (val > max_val) {
                max_idx = i;
                max_val = val;
            }
        }
        return max_val;
    }

    template<typename T>
    T min() const
    {
        if (getTensorType<T>() != type) {
            FT_LOG_DEBUG("getVal with type %s, but data type is: %s",
                         getNumpyTypeDesc(getTensorType<T>()).c_str(),
                         getNumpyTypeDesc(type).c_str());
        }
        FT_CHECK_WITH_INFO(shape.size() > 0 && data != nullptr, "Should be a non-empty tensor.");
        FT_CHECK_WITH_INFO(where == MEMORY_CPU || where == MEMORY_CPU_PINNED,
                           "min() supports MEMORY_CPU or MEMORY_CPU_PINNED tensor.");
        size_t min_idx = 0;
        T      min_val = getVal<T>(min_idx);
        for (size_t i = 1; i < size(); ++i) {
            T val = getVal<T>(i);
            if (val < min_val) {
                min_idx = i;
                min_val = val;
            }
        }
        return min_val;
    }

    template<typename T>
    T any(T val) const
    {
        if (getTensorType<T>() != type) {
            FT_LOG_DEBUG("getVal with type %s, but data type is: %s",
                         getNumpyTypeDesc(getTensorType<T>()).c_str(),
                         getNumpyTypeDesc(type).c_str());
        }
        FT_CHECK_WITH_INFO(shape.size() > 0 && data != nullptr, "Should be a non-empty tensor.");
        FT_CHECK_WITH_INFO(where == MEMORY_CPU || where == MEMORY_CPU_PINNED,
                           "any() supports MEMORY_CPU or MEMORY_CPU_PINNED tensor.");
        for (size_t i = 0; i < size(); ++i) {
            if (getVal<T>(i) == val) {
                return true;
            }
        }
        return false;
    }

    template<typename T>
    T all(T val) const
    {
        if (getTensorType<T>() != type) {
            FT_LOG_DEBUG("getVal with type %s, but data type is: %s",
                         getNumpyTypeDesc(getTensorType<T>()).c_str(),
                         getNumpyTypeDesc(type).c_str());
        }
        FT_CHECK_WITH_INFO(shape.size() > 0 && data != nullptr, "Should be a non-empty tensor.");
        FT_CHECK_WITH_INFO(where == MEMORY_CPU || where == MEMORY_CPU_PINNED,
                           "all() supports MEMORY_CPU or MEMORY_CPU_PINNED tensor.");
        for (size_t i = 0; i < size(); ++i) {
            if (getVal<T>(i) != val) {
                return false;
            }
        }
        return true;
    }

    void updateShape(size_t idx, size_t val)
    {
        // TODO: find a better way to update the shape
        std::vector<size_t>& shape_ref = const_cast<std::vector<size_t>&>(shape);
        shape_ref[idx]                 = val;
    }
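    // Returns a tensor that views the same underlying buffer (no copy), starting
    // `offset` elements into `data` and reinterpreted with the given `shape`.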
    Tensor slice(std::vector<size_t> shape, size_t offset = 0) const;

private:
    static void parseNpyIntro(FILE*& f_ptr, uint32_t& header_len, uint32_t& start_data);
    static int  parseNpyHeader(FILE*& f_ptr, uint32_t header_len, DataType& type, std::vector<size_t>& shape);
};

class TensorMap {
private:
    std::unordered_map<std::string, Tensor> tensor_map_;

    inline bool isValid(const Tensor& tensor)
    {
        return tensor.size() > 0 && tensor.data != nullptr;
    }

public:
    TensorMap() = default;
    TensorMap(const std::unordered_map<std::string, Tensor>& tensor_map);
    TensorMap(const std::vector<Tensor>& tensor_map);
    TensorMap(std::initializer_list<std::pair<std::string, Tensor>> tensor_map);
    ~TensorMap();

    inline size_t size() const
    {
        return tensor_map_.size();
    }

    inline bool isExist(const std::string& key) const
    {
        FT_LOG_DEBUG("%s for key: %s", __PRETTY_FUNCTION__, key.c_str());
        return tensor_map_.find(key) != tensor_map_.end();
    }

    std::vector<std::string> keys() const;

    inline void insert(const std::string& key, const Tensor& value)
    {
        FT_CHECK_WITH_INFO(!isExist(key), fmtstr("Duplicated key %s", key.c_str()));
        FT_CHECK_WITH_INFO(isValid(value), fmtstr("A none tensor or nullptr is not allowed (key is %s)", key.c_str()));
        tensor_map_.insert({key, value});
    }

    inline void insertIfValid(const std::string& key, const Tensor& value)
    {
        if (isValid(value)) {
            insert({key, value});
        }
    }

    inline void insert(std::pair<std::string, Tensor> p)
    {
        tensor_map_.insert(p);
    }

    // prevent converting int or size_t to string automatically
    Tensor at(int tmp)    = delete;
    Tensor at(size_t tmp) = delete;

    inline Tensor& at(const std::string& key)
    {
        FT_LOG_DEBUG("%s for key %s", __PRETTY_FUNCTION__, key.c_str());
        FT_CHECK_WITH_INFO(isExist(key),
                           fmtstr("Cannot find a tensor of name %s in the tensor map (keys: %s)",
                                  key.c_str(),
                                  vec2str(keys()).c_str()));
        return tensor_map_.at(key);
    }

    inline Tensor at(const std::string& key) const
    {
        FT_CHECK_WITH_INFO(isExist(key),
                           fmtstr("Cannot find a tensor of name %s in the tensor map (keys: %s)",
                                  key.c_str(),
                                  vec2str(keys()).c_str()));
        return tensor_map_.at(key);
    }

    inline Tensor& at(const std::string& key, Tensor& default_tensor)
    {
        FT_LOG_DEBUG("%s for key %s", __PRETTY_FUNCTION__, key.c_str());
        if (isExist(key)) {
            return tensor_map_.at(key);
        }
        return default_tensor;
    }

    inline Tensor at(const std::string& key, Tensor& default_tensor) const
    {
        FT_LOG_DEBUG("%s for key %s", __PRETTY_FUNCTION__, key.c_str());
        if (isExist(key)) {
            return tensor_map_.at(key);
        }
        return default_tensor;
    }

    inline Tensor& at(const std::string& key, Tensor&& default_tensor)
    {
        FT_LOG_DEBUG("%s for key %s", __PRETTY_FUNCTION__, key.c_str());
        if (isExist(key)) {
            return tensor_map_.at(key);
        }
        return default_tensor;
    }

    inline Tensor at(const std::string& key, Tensor&& default_tensor) const
    {
        if (isExist(key)) {
            return tensor_map_.at(key);
        }
        return default_tensor;
    }

    template<typename T>
    inline T getVal(const std::string& key) const
    {
        FT_CHECK_WITH_INFO(isExist(key),
                           fmtstr("Cannot find a tensor of name %s in the tensor map (keys: %s)",
                                  key.c_str(),
                                  vec2str(keys()).c_str()));
        return tensor_map_.at(key).getVal<T>();
    }

    template<typename T>
    inline T getVal(const std::string& key, T default_value) const
    {
        if (isExist(key)) {
            return tensor_map_.at(key).getVal<T>();
        }
        return default_value;
    }

    template<typename T>
    inline T getValWithOffset(const std::string& key, size_t index) const
    {
        FT_CHECK_WITH_INFO(isExist(key),
                           fmtstr("Cannot find a tensor of name %s in the tensor map (keys: %s)",
                                  key.c_str(),
                                  vec2str(keys()).c_str()));
        return tensor_map_.at(key).getVal<T>(index);
    }

    template<typename T>
    inline T getValWithOffset(const std::string& key, size_t index, T default_value) const
    {
        if (isExist(key)) {
            return tensor_map_.at(key).getVal<T>(index);
        }
        return default_value;
    }
    template<typename T>
    inline T* getPtr(const std::string& key) const
    {
        FT_CHECK_WITH_INFO(isExist(key),
                           fmtstr("Cannot find a tensor of name %s in the tensor map (keys: %s)",
                                  key.c_str(),
                                  vec2str(keys()).c_str()));
        return tensor_map_.at(key).getPtr<T>();
    }

    template<typename T>
    inline T* getPtr(const std::string& key, T* default_ptr) const
    {
        if (isExist(key)) {
            return tensor_map_.at(key).getPtr<T>();
        }
        return default_ptr;
    }

    template<typename T>
    inline T* getPtrWithOffset(const std::string& key, size_t index) const
    {
        FT_CHECK_WITH_INFO(isExist(key),
                           fmtstr("Cannot find a tensor of name %s in the tensor map (keys: %s)",
                                  key.c_str(),
                                  vec2str(keys()).c_str()));
        return tensor_map_.at(key).getPtrWithOffset<T>(index);
    }

    template<typename T>
    inline T* getPtrWithOffset(const std::string& key, size_t index, T* default_ptr) const
    {
        if (isExist(key)) {
            return tensor_map_.at(key).getPtrWithOffset<T>(index);
        }
        return default_ptr;
    }

    inline std::unordered_map<std::string, Tensor> getMap() const
    {
        return tensor_map_;
    }

    inline std::unordered_map<std::string, Tensor>::iterator begin()
    {
        return tensor_map_.begin();
    }

    inline std::unordered_map<std::string, Tensor>::iterator end()
    {
        return tensor_map_.end();
    }

    inline std::unordered_map<std::string, Tensor>& get()
    {
        return tensor_map_;
    }

    inline std::unordered_map<std::string, Tensor>::const_iterator begin() const
    {
        return tensor_map_.begin();
    }

    inline std::unordered_map<std::string, Tensor>::const_iterator end() const
    {
        return tensor_map_.end();
    }

    std::string toString();

    static TensorMap fromNpyFolder(const std::string& base_folder);
    void             saveNpy(const std::string& base_folder);
};

}  // namespace fastertransformer
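
// Example usage (a minimal sketch; the host buffer, key name, and shape below are
// hypothetical, and Tensor never owns or frees the memory it points to):
//
//     float h_lengths[2] = {3.f, 5.f};                           // caller-owned host buffer
//     Tensor lengths(MEMORY_CPU, TYPE_FP32, {2}, h_lengths);     // non-owning view of the buffer
//     float  first = lengths.getVal<float>();                    // element 0 -> 3.f
//     float  top   = lengths.max<float>();                       // CPU/pinned tensors only -> 5.f
//
//     TensorMap inputs({{"sequence_lengths", lengths}});
//     float* ptr = inputs.getPtr<float>("sequence_lengths");     // typed pointer to the same data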