Commit 0ea1cd55 authored by PanZezhong's avatar PanZezhong
Browse files

issue/248: fix total seqlen — read it on the CPU as int32

parent 8297a0b7
......@@ -101,7 +101,7 @@ StaticKVCache::update(size_t layer_idx,
v,
past_sequence_lengths);
#else
size_t cache_pos = reinterpret_cast<int64_t *>(past_sequence_lengths->to(infinicore::Device::cpu())->data())[0];
size_t cache_pos = reinterpret_cast<int32_t *>(past_sequence_lengths->to(infinicore::Device::cpu())->data())[0];
auto result_len = cache_pos + update_len;
ASSERT(result_len <= cache_len_);
......@@ -213,9 +213,9 @@ PagedKVCache::get_contiguous_kv(
const infinicore::Tensor cache_lens,
const infinicore::Tensor input_offsets,
size_t request_id) {
ASSERT_EQ(block_tables->dtype(), infinicore::DataType::I64);
ASSERT_EQ(cache_lens->dtype(), infinicore::DataType::I64);
ASSERT_EQ(input_offsets->dtype(), infinicore::DataType::I64);
ASSERT_EQ(block_tables->dtype(), infinicore::DataType::I32);
ASSERT_EQ(cache_lens->dtype(), infinicore::DataType::I32);
ASSERT_EQ(input_offsets->dtype(), infinicore::DataType::I32);
auto nreq = block_tables->size(0);
auto block_tables_cpu = block_tables->to(infinicore::Device::cpu());
......@@ -227,9 +227,9 @@ PagedKVCache::get_contiguous_kv(
auto &&[k_cache_layer, v_cache_layer] = get_paged_kv(layer_idx);
auto req = request_id;
auto cache_lens_ptr = reinterpret_cast<const int64_t *>(cache_lens_cpu->data());
auto input_offsets_ptr = reinterpret_cast<const int64_t *>(input_offsets_cpu->data());
int64_t total_len = cache_lens_ptr[req] + (input_offsets_ptr[req + 1] - input_offsets_ptr[req]);
auto cache_lens_ptr = reinterpret_cast<const int32_t *>(cache_lens_cpu->data());
auto input_offsets_ptr = reinterpret_cast<const int32_t *>(input_offsets_cpu->data());
int32_t total_len = cache_lens_ptr[req] + (input_offsets_ptr[req + 1] - input_offsets_ptr[req]);
auto full_k = infinicore::Tensor::empty(
{num_rank_k_heads_, (size_t)total_len, k_dim_},
......@@ -243,7 +243,7 @@ PagedKVCache::get_contiguous_kv(
size_t r = total_len % block_size_;
for (size_t b = 0; b < nblocks; b++) {
size_t bid = *((int64_t *)(block_tables_cpu->narrow({{0, req, 1}, {1, b, 1}})->data()));
size_t bid = *((int32_t *)(block_tables_cpu->narrow({{0, req, 1}, {1, b, 1}})->data()));
full_k->narrow({{1, b * block_size_, block_size_}})
->copy_from(k_cache_layer->narrow({{0, bid, 1}})->squeeze(0));
......@@ -252,7 +252,7 @@ PagedKVCache::get_contiguous_kv(
}
if (r > 0) {
size_t bid = *((int64_t *)(block_tables_cpu->narrow({{0, req, 1}, {1, nblocks, 1}})->data()));
size_t bid = *((int32_t *)(block_tables_cpu->narrow({{0, req, 1}, {1, nblocks, 1}})->data()));
full_k->narrow({{1, nblocks * block_size_, r}})
->copy_from(k_cache_layer->narrow({{0, bid, 1}})->squeeze(0)->narrow({{1, 0, r}}));
......
......@@ -209,7 +209,7 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta
->contiguous()
->view({batch_size, seq_len, num_attention_heads_ * head_dim_}); // [bs, seq_len, n_q_head * head_dim]
} else {
size_t total_seq_len = reinterpret_cast<int64_t *>(total_sequence_lengths.value()->to(infinicore::Device::cpu())->data())[0];
size_t total_seq_len = reinterpret_cast<int32_t *>(total_sequence_lengths.value()->to(infinicore::Device::cpu())->data())[0];
k_total = k_total->narrow({{2, 0, total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim]
v_total = v_total->narrow({{2, 0, total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment