/*
 * Copyright (c) OpenMMLab. All rights reserved.
 * Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Modified from
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/triton_backend/multi_gpu_gpt/ParallelGptTritonModel.h

#pragma once

#include "src/fastertransformer/models/llama/LlamaV2.h"
#include "src/fastertransformer/triton_backend/llama/LlamaTritonModelInstance.h"
#include "src/fastertransformer/triton_backend/transformer_triton_backend.hpp"
#include "src/fastertransformer/utils/cuda_utils.h"
#include "src/fastertransformer/utils/custom_ar_comm.h"
#include "src/fastertransformer/utils/nccl_utils.h"
#include <cuda_fp16.h>
#include <mutex>

namespace ft = fastertransformer;

template<typename T>
struct LlamaTritonSharedModelInstance;

template<typename T>
struct LlamaTritonModel: public AbstractTransformerModel {
    LlamaTritonModel(size_t      tensor_para_size,
                     size_t      pipeline_para_size,
                     int         enable_custom_all_reduce,
                     std::string model_dir);

    ~LlamaTritonModel() = default;

    std::unique_ptr<AbstractTransformerModelInstance>
    createModelInstance(int                                                               deviceId,
                        int                                                               rank,
                        cudaStream_t                                                      stream,
                        std::pair<std::vector<ft::NcclParam>, std::vector<ft::NcclParam>> nccl_params,
                        std::shared_ptr<ft::AbstractCustomComm> custom_all_reduce_comm = nullptr) override;

    void createSharedWeights(int deviceId, int rank) override;

    void createCustomComms(std::vector<std::shared_ptr<ft::AbstractCustomComm>>* custom_all_reduce_comms,
                           int                                                   world_size) override;

    std::pair<std::vector<ft::NcclParam>, std::vector<ft::NcclParam>>
    createNcclParams(const int node_id, const int device_id_start, const bool multi_node) override;

    std::unique_ptr<ft::AbstractInstanceComm> createInstanceComm(int size) override;

    void handleMissingParams();

    std::string toString() override;
    int         getTensorParaSize() override;
    int         getPipelineParaSize() override;

private:
    std::unique_ptr<LlamaTritonSharedModelInstance<T>>
    createSharedModelInstance(int                                                               deviceId,
                              int                                                               rank,
                              std::pair<std::vector<ft::NcclParam>, std::vector<ft::NcclParam>> nccl_params,
                              std::shared_ptr<ft::AbstractCustomComm> custom_all_reduce_comm = nullptr);

    // model hyper-parameters
    size_t         head_num_;
    size_t         size_per_head_;
    size_t         inter_size_;
    size_t         num_layer_;
    size_t         vocab_size_;
    size_t         rotary_embedding_dim_;
    float          norm_eps_;
    // runtime / scheduling parameters
    int            max_batch_size_;
    int            max_context_token_num_;
    int            session_len_;
    int            step_length_;
    int            start_id_;
    int            end_id_;
    int            cache_max_entry_count_;
    int            cache_chunk_size_;
    int            use_context_fmha_;
    size_t         tensor_para_size_;
    size_t         pipeline_para_size_;
    ft::WeightType weight_type_;
    bool           attn_bias_;
    int            quant_policy_;
    size_t         prefix_cache_len_{};

    // shared weights for each device
    std::vector<std::shared_ptr<ft::LlamaWeight<T>>> shared_weights_;

    std::shared_ptr<typename ft::LlamaV2<T>::SharedState> shared_state_;

    // weak_ptr is used so that the instances get released when all strong references are gone
    std::vector<std::weak_ptr<LlamaTritonSharedModelInstance<T>>> shared_instances_;
    std::deque<std::mutex>                                        shared_mutexes_;  // is locking really needed?

    // // residual type
    // bool use_gptj_residual_ = true;

    // // number of tasks (for prefix-prompt, p/prompt-tuning)
    // size_t num_tasks_ = 0;
    // int    prompt_learning_start_id_ = 0;
    // ft::PromptLearningType prompt_learning_type_ = ft::PromptLearningType::no_prompt;
    // std::map<std::string, std::pair<int, int>> prompt_learning_table_pair_ = {};

    bool is_fp16_;
    int  enable_custom_all_reduce_ = 0;

    std::string model_name_;
    std::string model_dir_;
};
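
// Illustrative usage sketch (commented out; not part of the upstream header).
// It only exercises the methods declared above and assumes a single-node,
// single-GPU, half-precision setup; the model directory path is a placeholder
// and the actual call order in the Triton backend may differ.
//
//   auto model = std::make_shared<LlamaTritonModel<half>>(
//       /*tensor_para_size=*/1,
//       /*pipeline_para_size=*/1,
//       /*enable_custom_all_reduce=*/0,
//       "/path/to/model_dir");
//
//   model->createSharedWeights(/*deviceId=*/0, /*rank=*/0);
//
//   auto nccl_params = model->createNcclParams(/*node_id=*/0,
//                                              /*device_id_start=*/0,
//                                              /*multi_node=*/false);
//
//   cudaStream_t stream{};
//   cudaStreamCreate(&stream);
//   auto instance = model->createModelInstance(/*deviceId=*/0, /*rank=*/0,
//                                              stream, nccl_params);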