Unverified Commit b190521b authored by Chen Xin, committed by GitHub

support image_embs input (#799)

* support image_embs input

* add some checks

* update interactive/config.pbtxt && TurbomindModelConfig

* update docstring

* refactor

* support convert embeddings to bf16

* update interactive/config.pbtxt

* embeddings -> input_embeddings

* use input_embedding_ranges

* remove embedding_begins/ends
parent af5a3edb
......@@ -59,6 +59,18 @@ input [
data_type: TYPE_UINT32
dims: [ -1 ]
},
{
name: "input_embeddings"
data_type: TYPE_INT8
dims: [ -1 ]
optional: true
},
{
name: "input_embedding_ranges"
data_type: TYPE_UINT32
dims: [ -1, 2 ]
optional: true
},
{
name: "step"
data_type: TYPE_INT32
......
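For context, the two new optional inputs carry the features as raw bytes plus index pairs: "input_embeddings" holds the embedding data reinterpreted as int8, and "input_embedding_ranges" holds [begin, end) offsets into the prompt as uint32 pairs. A minimal numpy sketch of that layout for one request (the hidden size and offsets are made up for illustration):

import numpy as np

hidden_size = 4096                                        # assumed model width
feats = np.zeros((9, hidden_size), dtype=np.float16)      # e.g. 9 image tokens

input_embeddings = feats.view(np.int8).reshape(-1)        # TYPE_INT8, dims [ -1 ]
input_embedding_ranges = np.array([[5, 14]], np.uint32)   # TYPE_UINT32, dims [ -1, 2 ]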
......@@ -459,6 +459,8 @@ class TurboMindInstance:
def stream_infer(self,
session_id,
input_ids,
input_embeddings=None,
input_embedding_ranges=None,
request_output_len: int = 512,
sequence_start: bool = True,
sequence_end: bool = False,
......@@ -476,6 +478,9 @@ class TurboMindInstance:
Args:
session_id (int): the id of a session
input_ids (numpy.ndarray): the token ids of a prompt
input_embeddings (List[numpy.ndarray]): embedding features
input_embedding_ranges (List[Tuple[int,int]]): the begin/end
offsets of input_embeddings within input_ids
request_output_len (int): the max number of to-be-generated tokens
sequence_start (bool): indicator for starting a sequence
sequence_end (bool): indicator for ending a sequence
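A hedged usage sketch of the two new arguments (tm_model, the hidden size, and the token offsets are assumptions for illustration; the dtype is converted internally to match the model's weight type):

import numpy as np

# assuming `tm_model` is a TurboMind model loaded elsewhere
generator = tm_model.create_instance()

input_ids = [1] * 20                                   # placeholder prompt tokens
image_feats = np.zeros((9, 4096), dtype=np.float16)    # hypothetical image features

for outputs in generator.stream_infer(
        session_id=1,
        input_ids=input_ids,
        input_embeddings=[image_feats],                # one list of features per request
        input_embedding_ranges=[(5, 14)],              # [begin, end) offsets into input_ids
        request_output_len=64):
    pass  # consume the streamed outputs here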
......@@ -544,6 +549,44 @@ class TurboMindInstance:
CORRID=np.array(session_id, dtype=np.uint64),
STOP=_broadcast_np((1 if stop else 0), np.int32))
if input_embeddings is not None:
assert len(input_embeddings) == len(input_embedding_ranges)
if isinstance(input_embeddings[0], np.ndarray):
input_embeddings = [input_embeddings]
input_embedding_ranges = [input_embedding_ranges]
# convert to lookup table type
if self.tm_model.config.weight_type == 'fp32':
input_embeddings = [[x.astype(np.float32) for x in y]
for y in input_embeddings]
elif self.tm_model.config.weight_type == 'bf16':
input_embeddings = [[
torch.from_numpy(x).bfloat16().view(torch.half).numpy()
for x in y
] for y in input_embeddings]
else:
input_embeddings = [[x.astype(np.float16) for x in y]
for y in input_embeddings]
input_embeddings = [[torch.from_numpy(x).squeeze() for x in y]
for y in input_embeddings]
input_embeddings = [torch.cat(x) for x in input_embeddings]
input_embeddings = pad_sequence(input_embeddings, batch_first=True)
input_embeddings = input_embeddings.reshape(
input_embeddings.shape[0], -1).view(torch.int8)
_input_embedding_ranges = []
for x in input_embedding_ranges:
if x is not None and len(x) != 0:
_input_embedding_ranges.append(torch.IntTensor(x))
else:
_input_embedding_ranges.append(torch.IntTensor(size=(0,
2)))
input_embedding_ranges = pad_sequence(_input_embedding_ranges,
batch_first=True,
padding_value=-1)
inputs['input_embeddings'] = input_embeddings
inputs['input_embedding_ranges'] = input_embedding_ranges
if ignore_eos:
stop_words = None
bad_words = torch.tensor([[[self.eos_id], [1]]], dtype=torch.int32)
......
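Note on the bf16 branch above: numpy has no bfloat16 dtype, so the features are cast to bfloat16 in torch, the 16-bit storage is reinterpreted as float16 so it survives the round trip through numpy, and the padded batch is finally viewed as raw int8 bytes. A stand-alone sketch of that round trip (shapes are arbitrary):

import numpy as np
import torch

feats = np.random.rand(9, 4096).astype(np.float32)

# cast to bf16, then reinterpret the same 16-bit storage as fp16 for numpy
bf16_as_fp16 = torch.from_numpy(feats).bfloat16().view(torch.half).numpy()

# flatten and reinterpret as bytes, matching the 'input_embeddings' payload
payload = torch.from_numpy(bf16_as_fp16).reshape(1, -1).view(torch.int8)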
......@@ -60,6 +60,21 @@ void ClearState(BatchState& s)
s.size = s.active_size = 0;
}
void DropEmbeddings(const Sequence& seq)
{
int seq_len = seq.tokens.size();
int num_emb = seq.input_embeddings.size();
size_t sz = num_emb;
for (; sz >= 1; sz--) {
if (seq.input_embedding_ranges[sz - 1].second <= seq_len) {
break;
}
}
// should we keep part of embedding?
seq.input_embeddings.resize(sz);
seq.input_embedding_ranges.resize(sz);
}
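In plain terms: when a session is rolled back, any stored embedding whose range ends beyond the kept tokens is dropped in full; partially covered embeddings are not truncated (hence the question in the comment above). A tiny Python illustration of the rule (numbers are made up):

ranges = [(5, 14), (20, 29), (40, 49)]   # [begin, end) per stored embedding
seq_len = 25                             # tokens kept after the rollback

keep = len(ranges)
while keep >= 1 and ranges[keep - 1][1] > seq_len:
    keep -= 1

print(ranges[:keep])   # [(5, 14)] -- (20, 29) is dropped although it overlaps the kept tokens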
template<typename T>
void LlamaBatch<T>::RejectInvalidRequests(Requests& stop_reqs, Requests& infer_reqs)
{
......@@ -234,6 +249,7 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
if (step <= seq.tokens.size()) {
seq.tokens.resize(step);
seq.cache_len = std::min(seq.cache_len, step);
DropEmbeddings(seq);
}
else if (rank_ == 0) {
TM_LOG_WARNING(
......@@ -258,6 +274,59 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
output_ids = Copy(input_ids, input_length, output_ids);
}
// copy input embeddings
if (r->inputs[rank_].isExist("input_embedding_ranges")) {
const auto range_tensor = r->inputs[rank_].at("input_embedding_ranges");
const auto emb_tensor = r->inputs[rank_].at("input_embeddings");
const int* ranges = range_tensor.getPtr<int>();
auto check_embeddings = [&](int& num_valid_embeddings) {
if (range_tensor.shape.size() != 3 || range_tensor.shape[2] % 2 != 0) {
return false;
}
int embedding_count = range_tensor.shape[1];
int embedding_length = 0;
int pre_end = -1;
for (size_t i = 0; i < embedding_count; i++) {
int begin = ranges[i * 2];
int end = ranges[i * 2 + 1];
embedding_length += (end - begin);
if (begin < 0 || end < 0) {
break;
}
if (begin >= end || end > input_length || begin < pre_end
|| embedding_length * model_->hidden_units_ * sizeof(T) > emb_tensor.shape[1]) {
return false;
}
pre_end = end;
num_valid_embeddings = i + 1;
}
return true;
};
int num_valid_embeddings = 0;
if (!check_embeddings(num_valid_embeddings)) {
TM_LOG_WARNING("[ImageFeature] Skip invalid input embeddings, id = %ld, input_length = %d, "
"input embeddings = %s, range_tensor = %s",
(long)seq.id,
input_length,
emb_tensor.toString().c_str(),
range_tensor.toString().c_str());
}
else {
char* emb_tensor_ptr = emb_tensor.getPtr<char>();
for (size_t i = 0; i < num_valid_embeddings; i++) {
int begin = ranges[i * 2];
int end = ranges[i * 2 + 1];
size_t count = (end - begin) * model_->hidden_units_ * sizeof(T);
seq.input_embeddings.emplace_back((std::byte*)emb_tensor_ptr, (std::byte*)(emb_tensor_ptr + count));
seq.input_embedding_ranges.emplace_back(begin + seq.tokens.size(), end + seq.tokens.size());
emb_tensor_ptr += count;
}
}
}
// total context length (history + input)
state.h_context_length[idx] = output_ids - output_ids_base;
state.h_finished[idx] = false;
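For reference, check_embeddings accepts ranges that are sorted, non-overlapping, non-empty, within the prompt, and small enough to fit the supplied embedding buffer; a negative pair acts as padding and ends the scan. A rough Python rendering of the same rule (illustrative only, not the actual implementation):

def count_valid_ranges(ranges, input_length, emb_bytes, hidden_units, elem_size):
    """Return how many leading [begin, end) pairs are valid, or None on error."""
    num_valid, used, pre_end = 0, 0, -1
    for begin, end in ranges:
        if begin < 0 or end < 0:                 # -1 padding: stop scanning
            break
        used += end - begin
        if (begin >= end or end > input_length or begin < pre_end
                or used * hidden_units * elem_size > emb_bytes):
            return None
        pre_end = end
        num_valid += 1
    return num_valid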
......@@ -1422,6 +1491,8 @@ bool LlamaBatch<T>::Forward(GenerationState& g, int iter)
std::vector<int> decode_indices{};
std::vector<int> decode_lengths{};
std::vector<const Sequence*> sequences;
BatchedCopy batched_copy;
for (int i = first; i < last; ++i) {
input_ids = batched_copy.Add(input_d_ptrs[i], h_input_length_buf_[i], input_ids);
......@@ -1438,6 +1509,7 @@ bool LlamaBatch<T>::Forward(GenerationState& g, int iter)
}
decode_indices.push_back(i);
decode_lengths.push_back(h_input_length_buf_[i]);
sequences.push_back(state_->sequences[i]);
max_input_len = std::max(max_input_len, h_input_length_buf_[i]);
}
int token_count = input_ids - context_decoder_ids_buf_;
......@@ -1484,7 +1556,9 @@ bool LlamaBatch<T>::Forward(GenerationState& g, int iter)
pf_batch_size,
max_input_len,
max_context_cnts[p],
max_context_cnts[p]);
max_context_cnts[p],
h_input_length_buf_ + first,
sequences.data());
if (iter == 0) {
// compute logits of inputs if requested
......
......@@ -165,6 +165,33 @@ void LlamaV2<T>::embeddingLookup(T* embeddings, const int* token_ids_buf, int ba
sync_check_cuda_error();
}
template<typename T>
void LlamaV2<T>::updateEmbedding(T* decoder_input, const int bsz, const int* h_input_length, const Sequence** sequences)
{
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
for (int i = 0; i < bsz; i++) {
const auto& seq = *sequences[i];
const auto& embeddings = seq.input_embeddings;
const auto& ranges = seq.input_embedding_ranges;
for (int j = embeddings.size() - 1; j >= 0; j--) {
int begin = ranges[j].first;
int end = ranges[j].second;
if (end <= seq.cache_len) {
break;
}
int off_dst = std::max(0, begin - seq.cache_len);
int off_src = std::max(0, seq.cache_len - begin);
size_t byte_size = (end - begin) * hidden_units_ * sizeof(T);
T* dst_ptr = decoder_input + off_dst * hidden_units_;
auto src_ptr = embeddings[j].data() + off_src * hidden_units_ * sizeof(T);
cudaMemcpyAsync(dst_ptr, src_ptr, byte_size, cudaMemcpyDefault, stream_);
}
decoder_input += h_input_length[i] * hidden_units_;
}
sync_check_cuda_error();
}
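The offset arithmetic above handles the case where a prefill chunk starts in the middle of an embedding span: off_dst is where the copy lands inside the current decoder input, off_src is how much of the stored embedding was already consumed by previous chunks, both measured against cache_len. A small numeric illustration (Python, values made up):

hidden_units = 8
cache_len = 10                 # tokens already processed and cached
begin, end = 7, 15             # this embedding covers prompt tokens [7, 15)

off_dst = max(0, begin - cache_len)   # 0: the span started before this chunk
off_src = max(0, cache_len - begin)   # 3: skip the positions already consumed

# dst = decoder_input + off_dst * hidden_units
# src = embeddings[j].data() + off_src * hidden_units * sizeof(T)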
template<typename T>
void LlamaV2<T>::forwardUnified(T* out,
T* decoder_output,
......@@ -187,7 +214,9 @@ void LlamaV2<T>::forwardUnified(T* out,
int pf_batch_size,
int pf_max_input_len,
int pf_max_context_len,
int pf_session_len)
int pf_session_len,
const int* h_input_length,
const Sequence** sequences)
{
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
......@@ -203,6 +232,9 @@ void LlamaV2<T>::forwardUnified(T* out,
1,
hidden_units_,
stream_);
updateEmbedding(decoder_input, dc_batch_size + pf_batch_size, h_input_length, sequences);
sync_check_cuda_error();
const auto dtype = getTensorType<T>();
......
......@@ -107,6 +107,8 @@ private:
void embeddingLookup(T* embeddings, const int* token_ids_buf, int batch_size, int step);
void updateEmbedding(T* decoder_input, const int bsz, const int* h_input_length, const Sequence** sequences);
void forwardUnified(T* out,
T* decoder_output,
T* decoder_input,
......@@ -128,7 +130,9 @@ private:
int pf_batch_size,
int pf_max_input_len,
int pf_max_context_len,
int pf_session_len);
int pf_session_len,
const int* h_input_length,
const Sequence** sequences);
void postDecodeEmbedding(float* logits, float* local_logits, const T* decoder_output, int batch_size);
......
......@@ -33,6 +33,10 @@ struct Sequence {
mutable float rope_theta = 0.f;
// embedding data
mutable std::vector<std::vector<std::byte>> input_embeddings;
mutable std::vector<std::pair<int, int>> input_embedding_ranges;
explicit Sequence(uint64_t _id): id(_id) {}
friend std::ostream& operator<<(std::ostream& os, const Sequence& seq);
......