/* * Copyright (c) OpenMMLab. All rights reserved. * Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // Modified from // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/triton_backend/multi_gpu_gpt/ParallelGptTritonModel.h #include "src/fastertransformer/triton_backend/llama/LlamaTritonModelInstance.h" #include "src/fastertransformer/triton_backend/transformer_triton_backend.hpp" #include "src/fastertransformer/triton_backend/triton_utils.hpp" #include "src/fastertransformer/utils/Tensor.h" #include "src/fastertransformer/utils/cuda_utils.h" #include #include #include #include #include #include namespace ft = fastertransformer; template void triton_stream_callback(std::unordered_map* output_tensors, void* ctx) { LlamaTritonModelInstance* model = reinterpret_cast*>(ctx); auto result = LlamaTritonModelInstance::convert_outputs(*output_tensors); model->stream_cb_(result, model->stream_ctx_); } template LlamaTritonModelInstance::LlamaTritonModelInstance( std::shared_ptr> instance, std::unique_ptr> allocator): instance_(std::move(instance)), allocator_(std::move(allocator)) { } template std::unordered_map LlamaTritonModelInstance::convert_inputs( std::shared_ptr> input_tensors) { FT_LOG_DEBUG(__PRETTY_FUNCTION__); move_tensor_H2D(input_tensors->at("input_ids"), d_input_ids_, &allocator_); move_tensor_H2D(input_tensors->at("input_lengths"), d_input_lengths_, &allocator_); const size_t request_batch_size = input_tensors->at("input_ids").shape[0]; const size_t input_data_len = input_tensors->at("input_ids").shape[1]; // freed in forward() h_total_output_lengths_ = reinterpret_cast(malloc(request_batch_size * sizeof(uint32_t))); std::unordered_map ft_input_tensors = std::unordered_map{ {"input_ids", as_GPU_tensor(input_tensors->at("input_ids"), d_input_ids_)}, // {"input_lengths", as_GPU_tensor(input_tensors->at("input_lengths"), d_input_lengths_)}, }; if (input_tensors->find("bad_words_list") != input_tensors->end()) { move_tensor_H2D(input_tensors->at("bad_words_list"), d_input_bad_words_, &allocator_); ft_input_tensors.insert( {"bad_words_list", as_GPU_tensor(input_tensors->at("bad_words_list"), d_input_bad_words_)}); } if (input_tensors->find("stop_words_list") != input_tensors->end()) { move_tensor_H2D(input_tensors->at("stop_words_list"), d_input_stop_words_, &allocator_); ft_input_tensors.insert( {"stop_words_list", as_GPU_tensor(input_tensors->at("stop_words_list"), d_input_stop_words_)}); } if (input_tensors->count("request_prompt_embedding") && input_tensors->count("request_prompt_lengths") && input_tensors->count("request_prompt_type")) { move_tensor_H2D(input_tensors->at("request_prompt_lengths"), d_request_prompt_lengths_, &allocator_); ft_input_tensors.insert( {"request_prompt_lengths", as_GPU_tensor(input_tensors->at("request_prompt_lengths"), d_request_prompt_lengths_)}); move_tensor_H2D(input_tensors->at("request_prompt_embedding"), d_request_prompt_embedding_, &allocator_); ft_input_tensors.insert( {"request_prompt_embedding", as_GPU_tensor(input_tensors->at("request_prompt_embedding"), d_request_prompt_embedding_)}); } if (input_tensors->find("top_p_decay") != input_tensors->end()) { move_tensor_H2D(input_tensors->at("top_p_decay"), d_top_p_decay_, &allocator_); ft_input_tensors.insert({"top_p_decay", as_GPU_tensor(input_tensors->at("top_p_decay"), d_top_p_decay_)}); } if (input_tensors->find("top_p_min") != input_tensors->end()) { move_tensor_H2D(input_tensors->at("top_p_min"), d_top_p_min_, &allocator_); ft_input_tensors.insert({"top_p_min", as_GPU_tensor(input_tensors->at("top_p_min"), d_top_p_min_)}); } if (input_tensors->find("top_p_reset_ids") != input_tensors->end()) { move_tensor_H2D(input_tensors->at("top_p_reset_ids"), d_top_p_reset_ids_, &allocator_); ft_input_tensors.insert( {"top_p_reset_ids", as_GPU_tensor(input_tensors->at("top_p_reset_ids"), d_top_p_reset_ids_)}); } for (auto t = input_tensors->begin(); t != input_tensors->end(); ++t) { if (t->first.find("input_ids") == std::string::npos // && t->first.find("input_lengths") == std::string::npos && t->first.find("output_seq_len") == std::string::npos && t->first.find("prefix_soft_prompt_embedding") == std::string::npos && t->first.find("prefix_soft_prompt_lengths") == std::string::npos) { if (ft_input_tensors.count(t->first) == 0) { ft_input_tensors.insert({t->first, t->second.convertTritonTensorToFt()}); } } } return ft_input_tensors; } template std::shared_ptr> LlamaTritonModelInstance::convert_outputs(const std::unordered_map& output_tensors) { FT_LOG_DEBUG(__PRETTY_FUNCTION__); std::unordered_map* outputs_mapping = new std::unordered_map(); for (auto it = output_tensors.begin(); it != output_tensors.end(); it++) { outputs_mapping->insert({it->first, triton::Tensor::convertFtTensorToTriton(it->second)}); } return std::shared_ptr>(outputs_mapping); } template std::shared_ptr> LlamaTritonModelInstance::forward(std::shared_ptr> input_tensors) { ft::FT_CHECK(false); return nullptr; } template std::shared_ptr> LlamaTritonModelInstance::forward(std::shared_ptr> input_tensors) { ft::FT_CHECK(false); return nullptr; } template std::string format_vector(const std::vector& vec) { std::stringstream ss; ss << "["; bool first = true; for (const auto& x : vec) { ss << (first ? "" : ", ") << x; first = false; } ss << "]"; return ss.str(); } template std::shared_ptr> LlamaTritonModelInstance::forward(std::shared_ptr> input_tensors, ft::AbstractInstanceComm* instance_comm) { FT_LOG_DEBUG(__PRETTY_FUNCTION__); // for (const auto& kv : *input_tensors) { // FT_LOG_INFO("%s: %s", kv.first.c_str(), format_vector(kv.second.shape).c_str()); // } FT_CHECK_WITH_INFO(input_tensors->at("input_ids").shape.size() == 2, "input_tensors->at(\"input_ids\").shape.size() == 2"); FT_CHECK_WITH_INFO(input_tensors->at("input_lengths").shape.size() == 1, "input_tensors->at(\"input_lengths\").shape.size() == 1"); const uint32_t request_batch_size = input_tensors->at("input_ids").shape[0]; const uint32_t max_request_output_len = (size_t)*std::max_element( (int*)input_tensors->at("request_output_len").data, (int*)input_tensors->at("request_output_len").data + input_tensors->at("request_output_len").shape[0]); // const uint32_t total_output_len = max_request_output_len + input_tensors->at("input_ids").shape[1]; const uint32_t beam_width = input_tensors->count("beam_width") ? (size_t)(*(uint*)input_tensors->at("beam_width").data) : 1; FT_CHECK_WITH_INFO(beam_width == 1, "Beam search is not implemented"); std::unordered_map ft_input_tensors = convert_inputs(input_tensors); allocateBuffer(request_batch_size, beam_width, instance_->session_len); std::unordered_map output_tensors = std::unordered_map{ {"output_ids", ft::Tensor{ft::MEMORY_GPU, ft::TYPE_UINT32, std::vector{request_batch_size, beam_width, (size_t)instance_->session_len}, d_output_ids_}}, {"sequence_length", ft::Tensor{ft::MEMORY_GPU, ft::TYPE_UINT32, std::vector{request_batch_size, beam_width}, d_sequence_lengths_}}}; if (input_tensors->count("is_return_log_probs") && *((bool*)input_tensors->at("is_return_log_probs").data)) { output_tensors.insert({"output_log_probs", ft::Tensor{ft::MEMORY_GPU, ft::TYPE_FP32, std::vector{request_batch_size, beam_width, max_request_output_len}, d_output_log_probs_}}); output_tensors.insert({"cum_log_probs", ft::Tensor{ft::MEMORY_GPU, ft::TYPE_FP32, std::vector{request_batch_size, beam_width}, d_cum_log_probs_}}); } try { ft::Request::Callback callback; if (stream_cb_) { callback = [this](std::unordered_map* outputs) { triton_stream_callback(outputs, this); }; } ft::check_cuda_error(cudaStreamSynchronize(allocator_->returnStream())); instance_->llm->forward(&output_tensors, &ft_input_tensors, {instance_comm, callback}); // ! stream synced by the model before returning } catch (...) { h_exception_ = std::current_exception(); output_tensors.insert({"error_message", ft::Tensor{ft::MEMORY_CPU, ft::TYPE_BYTES, {1}, &h_exception_}}); } if (h_total_output_lengths_ != nullptr) { free(h_total_output_lengths_); h_total_output_lengths_ = nullptr; } return convert_outputs(output_tensors); } template LlamaTritonModelInstance::~LlamaTritonModelInstance() { freeBuffer(); } template void LlamaTritonModelInstance::allocateBuffer(const size_t request_batch_size, const size_t beam_width, const size_t session_len) { d_output_ids_ = (int*)(allocator_->reMalloc(d_output_ids_, sizeof(int) * request_batch_size * beam_width * session_len, false)); d_sequence_lengths_ = (int*)(allocator_->reMalloc(d_sequence_lengths_, sizeof(int) * request_batch_size * beam_width, false)); d_output_log_probs_ = (float*)(allocator_->reMalloc( d_output_log_probs_, sizeof(float) * request_batch_size * beam_width * session_len, false)); d_cum_log_probs_ = (float*)(allocator_->reMalloc(d_cum_log_probs_, sizeof(float) * request_batch_size * beam_width, false)); } template void LlamaTritonModelInstance::freeBuffer() { allocator_->free((void**)(&d_output_ids_)); allocator_->free((void**)(&d_sequence_lengths_)); allocator_->free((void**)(&d_output_log_probs_)); allocator_->free((void**)(&d_cum_log_probs_)); } template struct LlamaTritonModelInstance; template struct LlamaTritonModelInstance;