Unverified Commit 0cc9d095 authored by Li Zhang, committed by GitHub

[Feature] decode-only forward pass (#153)

* decode only forward pass

* fix lint

* batch embedding
parent 7b470f07
# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp

import fire
import torch

from lmdeploy import turbomind as tm
from lmdeploy.turbomind.tokenizer import Tokenizer

os.environ['TM_LOG_LEVEL'] = 'ERROR'


def main(model_path, inputs):
    """An example to perform model inference through the command line
    interface.

    Args:
        model_path (str): the path of the deployed model
        inputs (str): the path of a text file containing input text lines
    """
    tokenizer_model_path = osp.join(model_path, 'triton_models', 'tokenizer')
    tokenizer = Tokenizer(tokenizer_model_path)
    tm_model = tm.TurboMind(model_path, eos_id=tokenizer.eos_token_id)
    generator = tm_model.create_instance()

    with open(inputs, 'r') as f:
        lines = f.readlines()

    # tokenize each input line and run a decode-only (prefill) forward pass
    input_ids = [tokenizer.encode(x) for x in lines]
    logits = generator.decode(input_ids)

    # greedy next-token prediction at every position
    top_1 = torch.argmax(logits, -1)
    print(top_1)


if __name__ == '__main__':
    fire.Fire(main)
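
The script above only prints raw argmax token ids. A small, hypothetical follow-up appended inside main() after print(top_1) could map the greedy predictions back to text. This assumes the Tokenizer also exposes a decode method mirroring encode, and relies on the returned logits having shape (batch, padded_len, vocab_size) with position i predicting token i + 1:

    # hypothetical extension of main(): inspect the greedy next-token predictions
    seq_lens = [len(tokenizer.encode(x)) for x in lines]  # prompt lengths before padding
    for i, (line, n) in enumerate(zip(lines, seq_lens)):
        pred_ids = top_1[i, :n].tolist()  # predictions for positions 1..n; the last is the next token
        print('prompt:', line.strip())
        print('greedy continuations:', tokenizer.decode(pred_ids))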
@@ -341,3 +341,52 @@ class TurboMindInstance:
        if stream_output:
            self.model_insts[0].unregister_callback()
+    def decode(self, input_ids):
+        """Perform context decode on input tokens.
+
+        Args:
+            input_ids (numpy.ndarray): the batch of input token ids
+        """
+        if len(input_ids) == 0:
+            input_ids = [[]]  # treat empty input as a batch of one empty sequence
+        if isinstance(input_ids[0], int):
+            input_ids = [input_ids]
+
+        # append an extra token since input_len-1 tokens will be
+        # decoded by context decoder
+        for inputs in input_ids:
+            inputs.append(0)
+
+        batch_size = len(input_ids)
+
+        def _broadcast_np(data, dtype, shape=(batch_size, )):
+            if isinstance(data, Iterable):
+                assert len(data) == batch_size
+                return data
+            return np.full(shape, data, dtype=dtype)
+
+        input_ids = [torch.IntTensor(ids) for ids in input_ids]
+        input_lengths = torch.IntTensor([len(ids) for ids in input_ids])
+        input_ids = pad_sequence(input_ids,
+                                 batch_first=True,
+                                 padding_value=self.eos_id)
+
+        inputs = dict(input_ids=input_ids,
+                      input_lengths=input_lengths,
+                      request_output_len=_broadcast_np(0, dtype=np.uint32),
+                      is_return_logits=_broadcast_np(1, np.uint32))
+
+        tm_inputs = _np_dict_to_tm_dict(inputs)
+
+        # start forward thread
+        self._forward_thread(tm_inputs)
+
+        _, tm_outputs = self.que.get()
+
+        outputs = _tm_dict_to_torch_dict(tm_outputs)
+        logits = outputs['logits']
+
+        # drop the logit produced by the appended dummy token
+        return logits[:, :-1, :]
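
A hypothetical usage sketch (not part of the patch) built on the method above: because decode() appends a dummy token per sequence and drops the last logit, position i of the returned tensor scores token i + 1, which makes prompt scoring straightforward. Here generator is assumed to be a TurboMindInstance and input_ids a list of token-id lists.

import torch
import torch.nn.functional as F


def sequence_logprobs(generator, input_ids):
    """Sum of per-token log-probabilities for each prompt in the batch."""
    lengths = [len(ids) for ids in input_ids]  # capture lengths before decode() pads/mutates them
    logits = generator.decode(input_ids)       # (batch, padded_len, vocab_size)
    log_probs = F.log_softmax(logits.float(), dim=-1)
    scores = []
    for b, (ids, n) in enumerate(zip(input_ids, lengths)):
        # logits[b, i] predicts token ids[i + 1]; only the first n positions are real
        targets = torch.tensor(ids[1:n], device=log_probs.device)
        scores.append(log_probs[b, :n - 1].gather(-1, targets[:, None]).sum().item())
    return scores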
@@ -155,6 +155,8 @@ void LlamaBatch<T>::allocateBuffer(size_t batch_size, size_t session_len)
    context_decoder_input_buf_ =
        (T*)allocator_->reMalloc(context_decoder_input_buf_, sizeof(T) * max_context_token_num_ * hidden_units, false);
+   context_decoder_output_buf_ =
+       (T*)allocator_->reMalloc(context_decoder_output_buf_, sizeof(T) * max_context_token_num_ * hidden_units, false);
    context_decoder_ids_buf_ =
        (int*)allocator_->reMalloc(context_decoder_ids_buf_, sizeof(int) * max_context_token_num_, false);
@@ -242,6 +244,7 @@ void LlamaBatch<T>::freeBuffer()
    TM_LOG_DEBUG(__PRETTY_FUNCTION__);
    if (is_allocate_buffer_) {
        allocator_->free((void**)&context_decoder_input_buf_);
+       allocator_->free((void**)&context_decoder_output_buf_);
        allocator_->free((void**)&context_decoder_ids_buf_);
        allocator_->free((void**)&decoder_input_buf_);
@@ -261,6 +264,13 @@ void LlamaBatch<T>::freeBuffer()
        allocator_->free((void**)&logits_buf_);
        allocator_->free((void**)&local_logits_buf_);
+       if (local_context_logits_buf_) {
+           allocator_->free((void**)&local_context_logits_buf_);
+       }
+       if (context_logits_buf_) {
+           allocator_->free((void**)&context_logits_buf_);
+       }
        allocator_->free((void**)&token_ids_buf_);
        allocator_->free((void**)&end_ids_buf_);
@@ -774,6 +784,9 @@ void LlamaBatch<T>::contextDecode()
    auto get_input_len = [this](int index) { return h_input_length_buf_[index] - 1; };
    auto get_context_len = [this](int index) { return h_context_length_buf_[index] - 1; };
+   std::vector<int> decode_indices{base};
+   std::vector<int> decode_lengths{get_input_len(base)};
    auto token_num = get_input_len(base);
    auto max_input_len = get_input_len(base);
    auto max_context_len = get_context_len(base);
@@ -807,7 +820,7 @@ void LlamaBatch<T>::contextDecode()
                                  k_cache_ptr_buf_ + offset,
                                  v_cache_ptr_buf_ + offset,
                                  context_decoder_input_buf_,
-                                 nullptr,
+                                 context_decoder_output_buf_,
                                  context_decoder_ids_buf_,
                                  input_length_buf_ + offset,
                                  history_length_buf_ + offset,
@@ -817,17 +830,29 @@ void LlamaBatch<T>::contextDecode()
                                  max_context_len,
                                  session_len_,
                                  context_decode_batch_size);
+
+           // compute logits of inputs if requested
+           outputContextLogits(context_decoder_output_buf_, decode_indices, decode_lengths);
+
            if (i < batch_size_) {
+               // initialize next sub-batch
                token_num = get_input_len(i);
                max_input_len = get_input_len(i);
                max_context_len = get_context_len(i);
                offset = i;
+
+               decode_indices = {i};
+               decode_lengths = {get_input_len(i)};
            }
        }
        else {
+           // add to current sub-batch
            token_num += get_input_len(i);
            max_input_len = std::max(max_input_len, get_input_len(i));
            max_context_len = std::max(max_context_len, get_context_len(i));
+
+           decode_indices.push_back(i);
+           decode_lengths.push_back(get_input_len(i));
        }
    }
@@ -849,6 +874,56 @@ void LlamaBatch<T>::contextDecode()
    }
}
+template<typename T>
+void LlamaBatch<T>::outputContextLogits(T*                      context_decoder_output,
+                                        const std::vector<int>& indices,
+                                        const std::vector<int>& lengths)
+{
+    std::vector<float*> output_logits;
+    int                 num_token = 0;
+    {
+        bool is_return_logits = false;
+        for (int k = 0; k < indices.size(); ++k) {
+            auto& request = requests_[indices[k]];
+            output_logits.push_back(request->outputs[rank_].getPtr<float>("logits", nullptr));
+            num_token += lengths[k];
+            if (output_logits.back()) {
+                is_return_logits = true;
+            }
+        }
+        if (!is_return_logits) {
+            return;
+        }
+    }
+
+    if (context_logits_buf_ == nullptr) {
+        NcclGuard guard(llama_->tensor_para_, stream_, true);
+        context_logits_buf_ = (float*)allocator_->malloc(sizeof(float) * llama_->vocab_size_ * max_context_token_num_);
+        const auto tp = llama_->tensor_para_.world_size_;
+        if (tp > 1) {
+            FT_CHECK(llama_->vocab_size_ % tp == 0);
+            const auto local_vocab_size = llama_->vocab_size_ / tp;
+            local_context_logits_buf_ =
+                (float*)allocator_->malloc(sizeof(float) * local_vocab_size * max_context_token_num_);
+        }
+    }
+
+    llama_->postDecodeEmbedding(context_logits_buf_, local_context_logits_buf_, context_decoder_output, num_token);
+
+    auto logits = context_logits_buf_;
+
+    for (int k = 0; k < indices.size(); ++k) {
+        if (output_logits[k]) {
+            check_cuda_error(cudaMemcpyAsync(output_logits[k],
+                                             logits,
+                                             sizeof(float) * llama_->vocab_size_ * lengths[k],
+                                             cudaMemcpyDefault,
+                                             stream_));
+        }
+        logits += llama_->vocab_size_ * lengths[k];
+    }
+}
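
An illustrative sketch (a Python analogue written for this write-up, not code from the patch) of the buffer layout the copy loop above walks: postDecodeEmbedding fills one contiguous buffer with logits for all num_token = sum(lengths) tokens of the sub-batch, and each request then receives its own [lengths[k], vocab_size] slice in packing order.

import numpy as np


def split_context_logits(packed, lengths, vocab_size):
    """Split a packed (sum(lengths), vocab_size) logits buffer per request."""
    packed = np.asarray(packed).reshape(-1, vocab_size)
    slices, offset = [], 0
    for n in lengths:
        slices.append(packed[offset:offset + n])  # logits for one request's prompt tokens
        offset += n
    return slices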
template<typename T>
void LlamaBatch<T>::finish()
{
......
@@ -3,6 +3,7 @@
#pragma once

#include "src/turbomind/models/llama/LlamaCacheManager.h"
+#include "src/turbomind/models/llama/LlamaNcclGuard.h"
#include "src/turbomind/models/llama/Request.h"
#include "src/turbomind/utils/allocator.h"
#include "src/turbomind/utils/cublasMMWrapper.h"
@@ -53,6 +54,9 @@ public:
    void setOutputTensors(int max_gen_step);

+    void
+    outputContextLogits(T* context_decoder_output, const std::vector<int>& indices, const std::vector<int>& lengths);
+
    explicit LlamaBatch(int max_batch_size, int max_context_token_num, int session_len, LlamaV2<T>* llama);

    ~LlamaBatch()
@@ -72,8 +76,8 @@ private:
    // active requests
    std::vector<std::shared_ptr<Request>> requests_;

    T* context_decoder_input_buf_{};  // CTXDEC
-   // T* context_decoder_output_buf_{};  // CTXDEC
+   T* context_decoder_output_buf_{};  // CTXDEC
    int* context_decoder_ids_buf_{};

    T* decoder_input_buf_{};  // CTXDEC, GENERATE
@@ -92,6 +96,8 @@ private:
    float* logits_buf_{};        // combined logits
    float* local_logits_buf_{};  // tensor parallel local logits
+   float* context_logits_buf_{};
+   float* local_context_logits_buf_{};

    // used by dynamic decoder
    int* token_ids_buf_{};  // all token IDs in [S, B], indexed using `step`
......
@@ -39,7 +39,6 @@ void LlamaContextDecoder<T>::allocateBuffer(size_t batch_size, size_t num_token,
{
    TM_LOG_DEBUG(__PRETTY_FUNCTION__);
-   attn_ffn_io_ = (T*)allocator_->reMalloc(attn_ffn_io_, sizeof(T) * num_token * hidden_units_, false);
    attention_mask_ = (T*)allocator_->reMalloc(attention_mask_, sizeof(T) * batch_size * max_q_len * max_kv_len, false);
    padding_offset_ = (int*)allocator_->reMalloc(padding_offset_, sizeof(int) * batch_size * max_q_len, false);
    cu_seqlens_ = (int*)allocator_->reMalloc(cu_seqlens_, sizeof(int) * (batch_size + 1), false);
@@ -52,7 +51,6 @@ void LlamaContextDecoder<T>::freeBuffer()
{
    TM_LOG_DEBUG(__PRETTY_FUNCTION__);
    if (is_allocate_buffer_) {
-       allocator_->free((void**)&attn_ffn_io_);
        allocator_->free((void**)&padding_offset_);
        allocator_->free((void**)&cu_seqlens_);
        allocator_->free((void**)&attention_mask_);
@@ -91,13 +89,14 @@ void LlamaContextDecoder<T>::initialize(size_t kv_head_num, bool use_fmha, int q
template<typename T>
void LlamaContextDecoder<T>::forwardSelfAttn(const Session& sess,
+                                            T* attn_io,
                                             const std::unordered_map<std::string, Tensor>* input_tensors,
                                             int layer,
                                             bool is_final)
{
    // TM_LOG_ERROR(__PRETTY_FUNCTION__);
    TensorMap self_attention_input_tensors{
-       {"input_query", Tensor{MEMORY_GPU, data_type_, {sess.token_num, hidden_units_}, attn_ffn_io_}},
+       {"input_query", Tensor{MEMORY_GPU, data_type_, {sess.token_num, hidden_units_}, attn_io}},
        {"attention_mask",
         {MEMORY_GPU, data_type_, {sess.batch_size, 1, sess.max_query_len, sess.max_key_len}, attention_mask_}},
        {"layer_id", Tensor{MEMORY_CPU, TYPE_INT32, {1}, &layer}},
@@ -113,7 +112,7 @@ void LlamaContextDecoder<T>::forwardSelfAttn(const Session&
    auto& v_cache = *sess.v_cache;
    TensorMap self_attention_output_tensors{
-       {"hidden_features", {MEMORY_GPU, data_type_, {sess.token_num, hidden_units_}, attn_ffn_io_}},
+       {"hidden_features", {MEMORY_GPU, data_type_, {sess.token_num, hidden_units_}, attn_io}},
        {"key_cache", k_cache},
        {"value_cache", v_cache},
    };
@@ -185,7 +184,7 @@ void LlamaContextDecoder<T>::forward(std::unordered_map<std::string, Tensor>*
 * \param max_seq_len [1], int on cpu
 *
 * output tensors:
- * \param decoder_output [batch_size, seq_len, hidden_units],
+ * \param decoder_output [num_token, hidden_units],
 * \param key_cache [num_layer, batch, local_head_num, size_per_head // x, max_seq_len, x]
 * \param value_cache [num_layer, batch, local_head_num, max_seq_len, size_per_head]
 * \param last_token_hidden_units [batch_size, hidden_units]
@@ -204,7 +203,7 @@ void LlamaContextDecoder<T>::forward(std::unordered_map<std::string, Tensor>*
    sess.context_length = input_tensors->at("context_lengths").getPtr<int>();
    T* decoder_input_output = input_tensors->at("decoder_input").getPtr<T>();
-   // T* decoder_output = output_tensors->at("decoder_output").getPtr<T>();
+   T* decoder_output = output_tensors->at("decoder_output").getPtr<T>();
    sess.k_cache = &output_tensors->at("key_cache");
    sess.v_cache = &output_tensors->at("value_cache");
@@ -234,7 +233,7 @@ void LlamaContextDecoder<T>::forward(std::unordered_map<std::string, Tensor>*
    /////////////////////////////////////////////
    /// RMSNorm
-   invokeRootMeanSquareNorm(attn_ffn_io_,
+   invokeRootMeanSquareNorm(decoder_output,
                             decoder_input_output,
                             decoder_layer_weights->at(0)->self_attn_norm_weights,
                             rmsnorm_eps_,
@@ -246,10 +245,10 @@ void LlamaContextDecoder<T>::forward(std::unordered_map<std::string, Tensor>*
    for (size_t layer = 0; layer < num_layer_; ++layer) {
        /////////////////////////////////////////////
        /// self-attention
-       forwardSelfAttn(sess, input_tensors, layer, false);
+       forwardSelfAttn(sess, decoder_output, input_tensors, layer, false);
        invokeFusedAddBiasResidualRMSNorm(decoder_input_output,
-                                         attn_ffn_io_,
+                                         decoder_output,
                                          decoder_layer_weights->at(layer)->self_attn_weights.output.bias,
                                          decoder_layer_weights->at(layer)->ffn_norm_weights,
                                          rmsnorm_eps_,
@@ -260,14 +259,15 @@ void LlamaContextDecoder<T>::forward(std::unordered_map<std::string, Tensor>*
        ////////////////////////////////////////////
        /// feed-forward network
-       TensorMap ffn_inputs{{"ffn_input", {MEMORY_GPU, data_type_, {sess.token_num, hidden_units_}, attn_ffn_io_}}};
-       TensorMap ffn_outputs{{"ffn_output", {MEMORY_GPU, data_type_, {sess.token_num, hidden_units_}, attn_ffn_io_}}};
+       TensorMap ffn_inputs{{"ffn_input", {MEMORY_GPU, data_type_, {sess.token_num, hidden_units_}, decoder_output}}};
+       TensorMap ffn_outputs{
+           {"ffn_output", {MEMORY_GPU, data_type_, {sess.token_num, hidden_units_}, decoder_output}}};
        silu_ffn_layer_->forward(&ffn_outputs, &ffn_inputs, &decoder_layer_weights->at(layer)->ffn_weights);
        auto scale_weight = layer < num_layer_ - 1 ? decoder_layer_weights->at(layer + 1)->self_attn_norm_weights :
                                                     input_tensors->at("output_norm_weight").getPtr<T>();
        invokeFusedAddBiasResidualRMSNorm(decoder_input_output,  //
-                                         attn_ffn_io_,
+                                         decoder_output,
                                          decoder_layer_weights->at(layer)->ffn_weights.output.bias,
                                          scale_weight,
                                          rmsnorm_eps_,
......
@@ -55,7 +55,6 @@ protected:
    NcclParam tensor_para_;
-   T* attn_ffn_io_{};
    T* attention_mask_{};
    int* padding_offset_{};
    int* cu_seqlens_{};  // cu for cumulative
@@ -82,6 +81,7 @@ protected:
    };
    void forwardSelfAttn(const Session& sess,
+                        T* attn_io,
                         const std::unordered_map<std::string, Tensor>* input_tensors,
                         int layer,
                         bool is_final);
......
@@ -86,6 +86,11 @@ public:
    void stop(const std::vector<uint64_t>& seq_ids);

+    size_t vocab_size() const noexcept
+    {
+        return vocab_size_;
+    }
+
private:
    friend class Batch;
......
@@ -193,7 +193,13 @@ LlamaTritonModelInstance<T>::forward(std::shared_ptr<std::unordered_map<std::str
    std::unordered_map<std::string, ft::Tensor> ft_input_tensors = convert_inputs(input_tensors);

-   allocateBuffer(request_batch_size, beam_width, instance_->session_len);
+   const size_t max_input_len = input_tensors->at("input_ids").shape[1];
+   const bool is_return_logits =
+       input_tensors->count("is_return_logits") && *(bool*)input_tensors->at("is_return_logits").data;
+
+   const size_t vocab_size = instance_->llm->vocab_size();
+
+   allocateBuffer(request_batch_size, max_input_len, beam_width, instance_->session_len, is_return_logits);

    std::unordered_map<std::string, ft::Tensor> output_tensors = std::unordered_map<std::string, ft::Tensor>{
        {"output_ids",
@@ -219,6 +225,13 @@ LlamaTritonModelInstance<T>::forward(std::shared_ptr<std::unordered_map<std::str
                      std::vector<size_t>{request_batch_size, beam_width},
                      d_cum_log_probs_}});
    }
+
+   if (is_return_logits) {
+       output_tensors.insert(
+           {"logits",
+            {ft::MEMORY_GPU, ft::TYPE_FP32, {request_batch_size, max_input_len, vocab_size}, d_output_logits_}});
+   }
+
    try {
        ft::Request::Callback callback;
@@ -253,8 +266,10 @@ LlamaTritonModelInstance<T>::~LlamaTritonModelInstance()
template<typename T>
void LlamaTritonModelInstance<T>::allocateBuffer(const size_t request_batch_size,
+                                                const size_t max_input_len,
                                                 const size_t beam_width,
-                                                const size_t session_len)
+                                                const size_t session_len,
+                                                const bool is_return_logits)
{
    d_output_ids_ =
        (int*)(allocator_->reMalloc(d_output_ids_, sizeof(int) * request_batch_size * beam_width * session_len, false));
@@ -264,6 +279,10 @@ void LlamaTritonModelInstance<T>::allocateBuffer(const size_t request_batch_size
        d_output_log_probs_, sizeof(float) * request_batch_size * beam_width * session_len, false));
    d_cum_log_probs_ =
        (float*)(allocator_->reMalloc(d_cum_log_probs_, sizeof(float) * request_batch_size * beam_width, false));
+
+   if (is_return_logits) {
+       d_output_logits_ = (float*)allocator_->reMalloc(
+           d_output_logits_, sizeof(float) * request_batch_size * max_input_len * instance_->llm->vocab_size(), false);
+   }
}

template<typename T>
......
@@ -66,7 +66,11 @@ private:
    std::unordered_map<std::string, ft::Tensor>
    convert_inputs(std::shared_ptr<std::unordered_map<std::string, triton::Tensor>> input_tensors);

-   void allocateBuffer(const size_t request_batch_size, const size_t beam_width, const size_t session_len);
+   void allocateBuffer(const size_t request_batch_size,
+                       const size_t max_input_len,
+                       const size_t beam_width,
+                       const size_t session_len,
+                       const bool is_return_logits);

    void freeBuffer();

    int* d_input_ids_ = nullptr;
@@ -83,6 +87,7 @@ private:
    int* d_sequence_lengths_ = nullptr;
    float* d_output_log_probs_ = nullptr;
    float* d_cum_log_probs_ = nullptr;
+   float* d_output_logits_ = nullptr;
    uint32_t* h_total_output_lengths_ = nullptr;
    std::exception_ptr h_exception_ = nullptr;
......