Unverified Commit 36f8eab7 authored by PanZezhong1725, committed by GitHub

Merge pull request #98 from InfiniTensor/issue/97

issue/97 Attention and KVCache support the batch dimension
parents 42f9d47d 6f624c94
#pragma once

#include "infinicore/context/context.hpp"
#include "infinicore/device.hpp"
#include "infinicore/tensor.hpp"

#include <algorithm>
#include <memory>
#include <numeric>
#include <stdexcept>
#include <utility>
#include <vector>

#include <spdlog/spdlog.h>
namespace infinilm::cache {
/**
 * @brief Simple KV cache structure for incremental decoding
 *
 * Stores key and value caches with shape [batch_size, n_kv_head, capacity, head_dim]
 * Similar to DynamicLayer in Python cache_utils.py
 *
 * This is a common component that can be used by any model architecture
 * that needs KV caching for attention mechanisms.
 */
struct KVCache {
    infinicore::Tensor k_cache;          // [batch_size, n_kv_head, capacity, head_dim]
    infinicore::Tensor v_cache;          // [batch_size, n_kv_head, capacity, head_dim]
    std::vector<size_t> cache_positions; // Current position in cache, per batch entry
    size_t max_capacity;                 // Maximum capacity of cache
    bool initialized;                    // Whether cache has been initialized
    KVCache() : max_capacity(0), initialized(false) {}
    /**
     * @brief Initialize or update cache capacity
@@ -40,34 +40,44 @@ struct KVCache {
     * @param dtype Data type
     * @param device Device
     */
    void ensure_capacity(size_t batch_size, size_t num_kv_heads, size_t head_dim, size_t seq_len,
                         infinicore::DataType dtype, const infinicore::Device &device) {
        // The new tokens must fit after the furthest-advanced sequence in the batch
        size_t required_capacity = seq_len + (cache_positions.empty() ? 0 : *std::max_element(cache_positions.begin(), cache_positions.end()));
        // Lazy initialization
        if (!initialized) {
            max_capacity = std::max(required_capacity, size_t(4096)); // Start with at least 4096
            k_cache = infinicore::Tensor::empty({batch_size, num_kv_heads, max_capacity, head_dim},
                                                dtype, device);
            v_cache = infinicore::Tensor::empty({batch_size, num_kv_heads, max_capacity, head_dim},
                                                dtype, device);
            cache_positions = std::vector<size_t>(batch_size, 0);
            initialized = true;
        }
        // Grow cache if needed (similar to DynamicLayer in Python)
        else if (required_capacity > max_capacity) {
            size_t new_capacity = std::max(max_capacity * 2, required_capacity + max_capacity);
            size_t new_batch_size = std::max(batch_size, k_cache->shape()[0]);
            if (num_kv_heads != k_cache->shape()[1] || head_dim != k_cache->shape()[3]) {
                throw std::runtime_error("KVCache ensure_capacity: num_kv_heads or head_dim mismatch with existing cache.");
            }
            if (new_batch_size > cache_positions.size()) {
                cache_positions.resize(new_batch_size, 0);
            }
            auto k_new = infinicore::Tensor::empty({new_batch_size, num_kv_heads, new_capacity, head_dim},
                                                   dtype, device);
            auto v_new = infinicore::Tensor::empty({new_batch_size, num_kv_heads, new_capacity, head_dim},
                                                   dtype, device);
            // Copy existing cache data
            for (size_t b = 0; b < new_batch_size; ++b) {
                size_t cache_position = cache_positions[b];
                if (cache_position > 0) {
                    auto k_slice = k_cache->narrow({{0, b, 1}, {2, 0, cache_position}});
                    auto v_slice = v_cache->narrow({{0, b, 1}, {2, 0, cache_position}});
                    k_new->narrow({{0, b, 1}, {2, 0, cache_position}})->copy_from(k_slice);
                    v_new->narrow({{0, b, 1}, {2, 0, cache_position}})->copy_from(v_slice);
                }
            }
            k_cache = k_new;
@@ -76,10 +86,16 @@ struct KVCache {
        }
    }
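With this policy the capacity at least doubles on each reallocation; for example, growing from max_capacity = 4096 with required_capacity = 4100 allocates max(8192, 8196) = 8196 slots, so a long decode loop triggers only a logarithmic number of copies rather than one per appended token.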
    KVCache(size_t max_batch_size, size_t n_kv_head, size_t head_dim, infinicore::DataType dtype,
            size_t max_seqlen = 4096, infinicore::Device device = infinicore::context::getDevice())
        : max_capacity(max_seqlen), initialized(false) {
        cache_positions = std::vector<size_t>(max_batch_size, 0);
        ensure_capacity(max_batch_size, n_kv_head, head_dim, max_capacity, dtype, device);
    }
    /**
     * @brief Update cache with new key and value states
     * @param k_new New key states [batch_size, n_kv_head, seq_len, head_dim]
     * @param v_new New value states [batch_size, n_kv_head, seq_len, head_dim]
     * @return Tuple of (k_total, v_total) with shape [batch_size, n_kv_head, total_seq_len, head_dim]
     *
     * Note: This method writes to the cache. If using with attention op, the attention op
@@ -88,28 +104,42 @@ struct KVCache {
    std::pair<infinicore::Tensor, infinicore::Tensor> update(
        const infinicore::Tensor &k_new,
        const infinicore::Tensor &v_new) {
        if (k_new->ndim() != 4 || v_new->ndim() != 4) {
            throw std::runtime_error("KVCache update: k_new and v_new must be 4D tensors in [batch_size, n_kv_head, seq_len, head_dim] form.");
        }
        size_t batch_size = k_new->shape()[0];
        size_t num_kv_heads = k_new->shape()[1];
        size_t seq_len = k_new->shape()[2];
        size_t head_dim = k_new->shape()[3];
        // Ensure capacity
        ensure_capacity(batch_size, num_kv_heads, head_dim, seq_len,
                        k_new->dtype(), k_new->device());

        // Copy new k/v into cache at the current position. The batched fast path
        // requires all sequences in the batch to sit at the same position.
        bool all_equal = cache_positions.empty() || std::equal(cache_positions.begin() + 1, cache_positions.end(), cache_positions.begin());
        if (all_equal) {
            auto cache_position = cache_positions[0];

            auto k_dst = k_cache->narrow({{2, cache_position, seq_len}});
            auto v_dst = v_cache->narrow({{2, cache_position, seq_len}});
            k_dst->copy_from(k_new);
            v_dst->copy_from(v_new);

            // Update position
            cache_position += seq_len;
            for (size_t b = 0; b < batch_size; ++b) {
                cache_positions[b] = cache_position;
            }

            // Return the total cache up to current position
            auto k_total = k_cache->narrow({{2, 0, cache_position}});
            auto v_total = v_cache->narrow({{2, 0, cache_position}});
            return std::make_pair(k_total, v_total);
        } else {
            throw std::runtime_error("KVCache update: cache positions must be equal among a batch.");
        }
    }
};
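A minimal usage sketch of the batched cache (illustrative only: the wrapper function and all dimension values are invented for the example, and the tensors are left uninitialized; only the Tensor::empty factory and update() shown above are assumed):

void kv_cache_usage_example() {
    using infinicore::DataType;
    using infinicore::Device;
    using infinicore::Tensor;

    // Lazy-initialized variant: storage is allocated on the first update()
    infinilm::cache::KVCache cache;

    // Prefill: 2 sequences, 4 KV heads, 5 prompt tokens, head_dim 64
    auto k = Tensor::empty({2, 4, 5, 64}, DataType::F32, Device(Device::Type::CPU, 0));
    auto v = Tensor::empty({2, 4, 5, 64}, DataType::F32, Device(Device::Type::CPU, 0));
    auto [k_total, v_total] = cache.update(k, v); // totals: [2, 4, 5, 64]

    // Decode: one new token per sequence, appended at position 5
    auto k_step = Tensor::empty({2, 4, 1, 64}, DataType::F32, Device(Device::Type::CPU, 0));
    auto v_step = Tensor::empty({2, 4, 1, 64}, DataType::F32, Device(Device::Type::CPU, 0));
    auto [k_total2, v_total2] = cache.update(k_step, v_step); // totals: [2, 4, 6, 64]
}

The preallocating constructor avoids even the first allocation during generation: KVCache cache(2, 4, 64, DataType::F32) reserves 4096 cache slots per sequence up front.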
@@ -72,128 +72,57 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
        throw std::runtime_error("Unexpected position_ids shape");
    }
    infinilm::cache::KVCache *external_cache = static_cast<infinilm::cache::KVCache *>(kv_cache);

    // Convert to [batch, n_head, seq_len, head_dim] for cache
    // Ensure contiguous after permute for F16 compatibility with cache operations
    q_reshaped = q_reshaped->permute({0, 2, 1, 3})->contiguous(); // [bs, n_q_head, seq_len, head_dim]
    auto k_permuted = k_reshaped->permute({0, 2, 1, 3});          // [bs, n_kv_head, seq_len, head_dim]
    auto v_permuted = v_reshaped->permute({0, 2, 1, 3});          // [bs, n_kv_head, seq_len, head_dim]

    // 4. Prepare KV caches
    infinicore::Tensor k_total; // [bs, n_kv_head, total_seq_len, head_dim]
    infinicore::Tensor v_total; // [bs, n_kv_head, total_seq_len, head_dim]
    if (external_cache != nullptr) {
        auto [k_total_tmp, v_total_tmp] = external_cache->update(k_permuted, v_permuted);
        k_total = k_total_tmp;
        v_total = v_total_tmp;
    } else {
        auto [k_total_tmp, v_total_tmp] = internal_cache_.update(k_permuted, v_permuted);
        k_total = k_total_tmp;
        v_total = v_total_tmp;
    }
    auto total_seq_len = k_total->shape()[2];

    // 5. Apply RoPE to full batch
    auto q_rope = q_reshaped->view({batch_size * num_attention_heads_, seq_len, head_dim_})->permute({1, 0, 2}); // [seq_len, bs * n_q_head, head_dim]
    auto k_rope = k_total->narrow({{2, total_seq_len - seq_len, seq_len}})->view({batch_size * num_key_value_heads_, seq_len, head_dim_})->permute({1, 0, 2}); // [seq_len, bs * n_kv_head, head_dim]
    rotary_emb_->forward(q_rope, pos_ids_for_rope, true);
    rotary_emb_->forward(k_rope, pos_ids_for_rope, true);
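Note that k_rope is built by narrow()/view() from k_total, which in turn aliases k_cache, so the in-place RoPE call (assuming the trailing true argument selects in-place operation) rotates the newly appended keys directly inside the cache; keys are therefore stored post-rotation and never re-rotated on later decode steps.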
    // 6. Compute attention
    size_t ngroup = num_attention_heads_ / num_key_value_heads_;
    auto Q = q_reshaped->view({batch_size * num_key_value_heads_, ngroup * seq_len, head_dim_});
    auto K = k_total->view({batch_size * num_key_value_heads_, total_seq_len, head_dim_});
    auto V = v_total->view({batch_size * num_key_value_heads_, total_seq_len, head_dim_});

    auto K_transposed = K->permute({0, 2, 1}); // [bs * n_kv_head, head_dim, total_seq_len]

    float scaling = 1.0f / std::sqrt(static_cast<float>(head_dim_));
    auto attn_weight = infinicore::op::matmul(Q, K_transposed, scaling); // [bs * n_kv_head, ng * seq_len, total_seq_len]

    // Causal masking is per query head, so re-view the rows as [bs * n_q_head, seq_len, total_seq_len]
    auto attn_weight_softmax = attn_weight->view({batch_size * num_attention_heads_, seq_len, total_seq_len});
    infinicore::op::causal_softmax_(attn_weight_softmax, attn_weight_softmax);

    auto out = infinicore::op::matmul(attn_weight, V); // [bs * n_kv_head, ng * seq_len, head_dim]

    auto attn_output = out->view({batch_size, num_attention_heads_, seq_len, head_dim_})
                           ->permute({0, 2, 1, 3})
                           ->contiguous()
                           ->view({batch_size, seq_len, num_attention_heads_ * head_dim_}); // [bs, seq_len, n_q_head * head_dim]
    // 7. Apply output projection
    auto output = o_proj_->forward(attn_output);

    return output;
}
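The single fused GEMM above works because viewing Q as [bs * n_kv_head, ngroup * seq_len, head_dim] groups exactly the ngroup query heads that share one KV head into a single batch entry, so K and V never need the repeat_kv/as_strided expansion the old per-batch loop performed. A self-contained sketch of the index arithmetic behind that view (dimension values invented for illustration):

#include <cassert>
#include <cstddef>

int main() {
    const std::size_t bs = 2, n_kv = 4, ngroup = 3, seq = 5, hd = 8;
    const std::size_t n_q = n_kv * ngroup;
    // Row-major offset of element (b, qh, s, 0) in Q[bs][n_q][seq][hd]
    auto q_offset = [&](std::size_t b, std::size_t qh, std::size_t s) {
        return ((b * n_q + qh) * seq + s) * hd;
    };
    // Offset of row r in batch entry grp of the view Qv[bs*n_kv][ngroup*seq][hd]
    auto view_offset = [&](std::size_t grp, std::size_t r) {
        return (grp * ngroup * seq + r) * hd;
    };
    for (std::size_t b = 0; b < bs; ++b)
        for (std::size_t kvh = 0; kvh < n_kv; ++kvh)
            for (std::size_t g = 0; g < ngroup; ++g)
                for (std::size_t s = 0; s < seq; ++s) {
                    std::size_t qh = kvh * ngroup + g; // query head sharing KV head kvh
                    assert(q_offset(b, qh, s) == view_offset(b * n_kv + kvh, g * seq + s));
                }
    return 0;
}

The same mapping justifies re-viewing attn_weight as [bs * n_q_head, seq_len, total_seq_len] before causal_softmax_: row g * seq_len + s of a KV-head group is exactly row s of query head kvh * ngroup + g.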
@@ -63,11 +63,23 @@ def get_args():
        default="float32",
        help="float32, float16, bfloat16",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="number of prompts in a batch",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="How are you",
        help="input prompt",
    )
    return parser.parse_args()
def test(
    prompts: str | list[str],
    model_path,
    max_new_tokens=100,
    infini_dtype=infinicore.bfloat16,
@@ -123,18 +135,24 @@ def test(
    # Token encoding
    # ---------------------------------------------------------------------------- #
    # prompt = "What is the highest mountain in Shandong?"
    if isinstance(prompts, str):
        prompts = [prompts]
    input_contents = [
        tokenizer.apply_chat_template(
            conversation=[{"role": "user", "content": prompt}],
            add_generation_prompt=True,
            tokenize=False,
        )
        for prompt in prompts
    ]
    print(input_contents[0], end="", flush=True)
    input_ids_list = tokenizer.batch_encode_plus(input_contents)[
        "input_ids"
    ]  # List: [[1, 1128, 526, 366, 29892]]
    # ---------------------------------------------------------------------------- #
    # Autoregressive generation
    # ---------------------------------------------------------------------------- #
    input_ids_infini = infinicore.from_list(input_ids_list)

    t1 = time.time()
@@ -175,7 +193,7 @@ if __name__ == "__main__":
"such as, python examples/llama.py --nvidia --model_path=~/TinyLlama-1.1B-Chat-v1.0" "such as, python examples/llama.py --nvidia --model_path=~/TinyLlama-1.1B-Chat-v1.0"
) )
sys.exit(1) sys.exit(1)
prompt = "How are you" prompts = [args.prompt for _ in range(args.batch_size)]
model_path = args.model_path model_path = args.model_path
max_new_tokens = args.max_new_tokens max_new_tokens = args.max_new_tokens
@@ -192,7 +210,7 @@ if __name__ == "__main__":
        raise ValueError(f"Unsupported dtype: {args.dtype}")

    test(
        prompts,
        model_path,
        max_new_tokens,
        infini_device=infini_device,
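With the new flags, a batched run of the example looks like (illustrative invocation): python examples/llama.py --nvidia --model_path=~/TinyLlama-1.1B-Chat-v1.0 --batch_size=4 --prompt "How are you", which decodes four copies of the same prompt in a single forward pass per step.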
@@ -100,9 +100,11 @@ class GenerationMixin:
        # -------------------------------------------------------------------- #
        # Required: the input_ids of the next tokens
        # -------------------------------------------------------------------- #
        if kwargs.get("next_token_ids", None) is not None:
            next_token_ids = kwargs["next_token_ids"]
            model_inputs["input_ids"] = infinicore.from_list(
                [[id_] for id_ in next_token_ids],  # [batch_size, 1] decode-step input
            )

        # -------------------------------------------------------------------- #
        # Others
@@ -236,7 +238,7 @@ class GenerationMixin:
            token_id = next_tokens.to_numpy()[0]  # stream-decode only the first sequence
            output_str = tokenizer.decode([token_id], skip_special_tokens=True)

            model_kwargs["next_token_ids"] = next_tokens.to_numpy().tolist()
            output_tokens_list.append(token_id)
            output_content += output_str
@@ -245,11 +247,16 @@
                break

        print("\n</s>")
        print(f"\n\n\n Generation completed in {round(sum(time_list), 2)} ms")
        print(
            f" Batchsize={batch_size} Per_Batch_Input_Len={seq_len} Per_Batch_New_Tokens={len(time_list)}\n"
        )
        print(
            f" Prefill TTFT: {round(time_list[0], 2)}ms Throughput: {round((1000 * batch_size * seq_len) / time_list[0], 2)}tok/s\n",
        )
        if len(time_list) > 1:
            print(
                f" Decode Avg ITL: {round(sum(time_list[1:]) / (len(time_list) - 1), 2)}ms Throughput: {round((1000 * batch_size * (len(time_list) - 1)) / sum(time_list[1:]), 2)}tok/s\n",
            )

        return output_tokens_list, output_content
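For illustration (numbers invented): with batch_size = 4 and seq_len = 16, a 50 ms prefill gives 1000 * 4 * 16 / 50 = 1280 tok/s prefill throughput; if 99 decode steps take 990 ms in total, the average inter-token latency is 990 / 99 = 10 ms and decode throughput is 1000 * 4 * 99 / 990 = 400 tok/s.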