Commit 6f624c94 authored by PanZezhong

issue/97: add batch support to the inference script

parent 0da7b5db
@@ -10,6 +10,8 @@
 #include <stdexcept>
 #include <utility>
+#include <spdlog/spdlog.h>
 namespace infinilm::cache {
 /**
...
@@ -72,44 +72,38 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
         throw std::runtime_error("Unexpected position_ids shape");
     }
-    // 4. Apply RoPE to full batch
-    auto q_for_rope = q_reshaped->view({batch_size * seq_len, num_attention_heads_, head_dim_});
-    auto k_for_rope = k_reshaped->view({batch_size * seq_len, num_key_value_heads_, head_dim_});
-    // Call RoPE on full batch (matching Python pattern)
-    auto q_rope_out = rotary_emb_->forward(q_for_rope, pos_ids_for_rope);
-    auto k_rope_out = rotary_emb_->forward(k_for_rope, pos_ids_for_rope);
-    // Reshape back to [batch_size, seq_len, num_heads, head_dim] (matching Python pattern)
-    q_rope_out = q_rope_out->view({batch_size, seq_len, num_attention_heads_, head_dim_});
-    k_rope_out = k_rope_out->view({batch_size, seq_len, num_key_value_heads_, head_dim_});
-    // 5. Process each batch item separately for attention computation
+    // 4. Process each batch item separately for attention computation
     infinilm::cache::KVCache *external_cache = static_cast<infinilm::cache::KVCache *>(kv_cache);
     // Convert to [batch, n_head, seq_len, head_dim] for cache
    // Ensure contiguous after permute for F16 compatibility with cache operations
-    auto q_rope = q_rope_out->permute({0, 2, 1, 3})->contiguous(); // [bs, n_q_head, seq_len, head_dim]
-    auto k_rope = k_rope_out->permute({0, 2, 1, 3});               // [bs, n_kv_head, seq_len, head_dim]
+    q_reshaped = q_reshaped->permute({0, 2, 1, 3})->contiguous(); // [bs, n_q_head, seq_len, head_dim]
+    auto k_permuted = k_reshaped->permute({0, 2, 1, 3});          // [bs, n_kv_head, seq_len, head_dim]
     auto v_permuted = v_reshaped->permute({0, 2, 1, 3});          // [bs, n_kv_head, seq_len, head_dim]
-    // 5. Prepare KV caches
+    // 4. Prepare KV caches
     infinicore::Tensor k_total; // [bs, n_kv_head, total_seq_len, head_dim]
     infinicore::Tensor v_total; // [bs, n_kv_head, total_seq_len, head_dim]
     if (external_cache != nullptr) {
-        auto [k_total_tmp, v_total_tmp] = external_cache->update(k_rope, v_permuted);
+        auto [k_total_tmp, v_total_tmp] = external_cache->update(k_permuted, v_permuted);
         k_total = k_total_tmp;
         v_total = v_total_tmp;
     } else {
-        auto [k_total_tmp, v_total_tmp] = internal_cache_.update(k_rope, v_permuted);
+        auto [k_total_tmp, v_total_tmp] = internal_cache_.update(k_permuted, v_permuted);
         k_total = k_total_tmp;
         v_total = v_total_tmp;
     }
     auto total_seq_len = k_total->shape()[2];
+    // 5. Apply RoPE to full batch
+    auto q_rope = q_reshaped->view({batch_size * num_attention_heads_, seq_len, head_dim_})->permute({1, 0, 2}); // [seq_len, bs * n_q_head, head_dim]
+    auto k_rope = k_total->narrow({{2, total_seq_len - seq_len, seq_len}})->view({batch_size * num_key_value_heads_, seq_len, head_dim_})->permute({1, 0, 2}); // [seq_len, bs * n_kv_head, head_dim]
+    rotary_emb_->forward(q_rope, pos_ids_for_rope, true);
+    rotary_emb_->forward(k_rope, pos_ids_for_rope, true);
     // 6. Compute attention
     size_t ngroup = num_attention_heads_ / num_key_value_heads_;
-    auto Q = q_rope->view({batch_size * num_key_value_heads_, ngroup * seq_len, head_dim_});
+    auto Q = q_reshaped->view({batch_size * num_key_value_heads_, ngroup * seq_len, head_dim_});
     auto K = k_total->view({batch_size * num_key_value_heads_, total_seq_len, head_dim_});
     auto V = v_total->view({batch_size * num_key_value_heads_, total_seq_len, head_dim_});
...
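Note on the reordering above: Q/K/V are now permuted to [bs, n_head, seq_len, head_dim] and the KV cache is updated first; RoPE is then applied through a [seq_len, bs * n_head, head_dim] view of the newly appended positions (the return value is unused and Q is later built from q_reshaped, which suggests the trailing true argument requests an in-place rotation). A minimal NumPy sketch of the shape bookkeeping, illustrative only and not the infinicore tensor API:

import numpy as np

# Hypothetical sizes for illustration
batch_size, seq_len, head_dim = 2, 5, 64
num_attention_heads, num_key_value_heads = 32, 8
ngroup = num_attention_heads // num_key_value_heads  # query heads per KV head

# q after permute({0, 2, 1, 3}) + contiguous: [bs, n_q_head, seq_len, head_dim]
q = np.random.randn(batch_size, num_attention_heads, seq_len, head_dim).astype(np.float32)

# RoPE view: flatten batch and heads, then bring seq_len to the front
q_rope = q.reshape(batch_size * num_attention_heads, seq_len, head_dim).transpose(1, 0, 2)
assert q_rope.shape == (seq_len, batch_size * num_attention_heads, head_dim)

# Attention view: group query heads onto their shared KV head
Q = q.reshape(batch_size * num_key_value_heads, ngroup * seq_len, head_dim)
assert Q.shape == (batch_size * num_key_value_heads, ngroup * seq_len, head_dim)

In this sketch both views alias the same buffer, mirroring how an in-place RoPE update through q_rope is visible to the later attention view.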
@@ -63,11 +63,23 @@ def get_args():
         default="float32",
         help="float32, float16, bfloat16",
     )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=1,
+        help="number of prompts in a batch",
+    )
+    parser.add_argument(
+        "--prompt",
+        type=str,
+        default="How are you",
+        help="input prompt",
+    )
     return parser.parse_args()


 def test(
-    prompt,
+    prompts: str | list[str],
     model_path,
     max_new_tokens=100,
     infini_dtype=infinicore.bfloat16,
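With the two new flags, a typical invocation (model path taken from the usage string later in the script) is: python examples/llama.py --nvidia --model_path=~/TinyLlama-1.1B-Chat-v1.0 --batch_size=4 --prompt="How are you". Since the batch is built by repeating the same prompt batch_size times, every row of the batch has the same token length.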
@@ -123,18 +135,24 @@ def test(
     # Token encoding
     # ---------------------------------------------------------------------------- #
     # prompt = "山东最高的山是?"
-    input_content = tokenizer.apply_chat_template(
-        conversation=[{"role": "user", "content": prompt}],
-        add_generation_prompt=True,
-        tokenize=False,
-    )
-    print(input_content, end="", flush=True)
-    input_ids = tokenizer.encode(input_content)
+    if isinstance(prompts, str):
+        prompts = [prompts]
+    input_contents = [
+        tokenizer.apply_chat_template(
+            conversation=[{"role": "user", "content": prompt}],
+            add_generation_prompt=True,
+            tokenize=False,
+        )
+        for prompt in prompts
+    ]
+    print(input_contents[0], end="", flush=True)
+    input_ids_list = tokenizer.batch_encode_plus(input_contents)[
+        "input_ids"
+    ]  # List: [[1, 1128, 526, 366, 29892]]

     # ---------------------------------------------------------------------------- #
     # Autoregressive generation
     # ---------------------------------------------------------------------------- #
-    input_ids_list = [input_ids]  # List: [[1, 1128, 526, 366, 29892]]
     input_ids_infini = infinicore.from_list(input_ids_list)

     t1 = time.time()
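For reference, a hedged sketch of what this tokenization path yields, assuming a standard Hugging Face tokenizer with a chat template (the model path below is only an example):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
prompts = ["How are you", "How are you"]
input_contents = [
    tokenizer.apply_chat_template(
        conversation=[{"role": "user", "content": p}],
        add_generation_prompt=True,
        tokenize=False,
    )
    for p in prompts
]
# Without return_tensors, batch_encode_plus returns plain Python lists:
# one list of token ids per prompt.
input_ids_list = tokenizer.batch_encode_plus(input_contents)["input_ids"]
print(len(input_ids_list), [len(ids) for ids in input_ids_list])

Because the example script repeats a single prompt, all rows have equal length, so infinicore.from_list can build a rectangular [batch_size, seq_len] tensor without padding.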
@@ -175,7 +193,7 @@ if __name__ == "__main__":
             "such as, python examples/llama.py --nvidia --model_path=~/TinyLlama-1.1B-Chat-v1.0"
         )
         sys.exit(1)

-    prompt = "How are you"
+    prompts = [args.prompt for _ in range(args.batch_size)]
     model_path = args.model_path
     max_new_tokens = args.max_new_tokens
@@ -192,7 +210,7 @@ if __name__ == "__main__":
         raise ValueError(f"Unsupported dtype: {args.dtype}")

     test(
-        prompt,
+        prompts,
         model_path,
         max_new_tokens,
         infini_device=infini_device,
...
@@ -100,9 +100,11 @@ class GenerationMixin:
         # -------------------------------------------------------------------- #
         # Required: the input_ids of the tokens
         # -------------------------------------------------------------------- #
-        if kwargs.get("next_token_id", None) is not None:
-            next_token_id = kwargs["next_token_id"]
-            model_inputs["input_ids"] = infinicore.from_list([[next_token_id]])
+        if kwargs.get("next_token_ids", None) is not None:
+            next_token_ids = kwargs["next_token_ids"]
+            model_inputs["input_ids"] = infinicore.from_list(
+                [[id_] for id_ in next_token_ids],
+            )

         # -------------------------------------------------------------------- #
         # Other
@@ -236,7 +238,7 @@ class GenerationMixin:
             token_id = next_tokens.to_numpy()[0]
             output_str = tokenizer.decode([token_id], skip_special_tokens=True)

-            model_kwargs["next_token_id"] = token_id
+            model_kwargs["next_token_ids"] = next_tokens.to_numpy().tolist()
             output_tokens_list.append(token_id)
             output_content += output_str
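A plain-Python sketch (hypothetical token ids, not the infinicore API) of how the per-row next tokens become the decode-step input_ids prepared above:

# One greedy/sampled token id per batch row, as produced by next_tokens.to_numpy().tolist()
next_token_ids = [29892, 3869]            # hypothetical ids for a batch of 2
decode_input = [[id_] for id_ in next_token_ids]
assert decode_input == [[29892], [3869]]  # shape [batch_size, 1], one new token per row

Note that token_id = next_tokens.to_numpy()[0] still tracks only the first batch row, so the streamed text and output_tokens_list follow batch item 0 while every row advances through the model.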
@@ -245,11 +247,16 @@ class GenerationMixin:
                 break

         print("\n</s>")
+        print(f"\n\n\n Generation completed in {round(sum(time_list),2)} ms")
         print(
-            f"\n\n\n Time per step: prefill {round(time_list[0], 2)} ms/token\n",
+            f" Batchsize={batch_size} Per_Batch_Input_Len={seq_len} Per_Batch_New_Tokens={len(time_list)}\n"
         )
         print(
-            f" Time per step: decoder {round(sum(time_list[1:]) / (len(time_list) - 1), 2)} ms/token \n",
+            f" Prefill TTFT: {round(time_list[0], 2)}ms Throughput: {round((1000 * batch_size * seq_len)/time_list[0], 2)}tok/s\n",
         )
+        if len(time_list) > 1:
+            print(
+                f" Decode Avg ITL: {round(sum(time_list[1:]) / (len(time_list) - 1), 2)}ms Throughput: {round((1000 * batch_size * (len(time_list) - 1))/ sum(time_list[1:]), 2)}tok/s\n",
+            )

         return output_tokens_list, output_content
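A worked example of the reported metrics with made-up timings: prefill throughput counts every prompt token across the batch, while decode throughput counts one token per batch row per decode step.

batch_size, seq_len = 4, 6             # hypothetical batch size and per-prompt length
time_list = [80.0, 20.0, 21.0, 19.0]   # ms; time_list[0] is prefill, the rest are decode steps

prefill_ttft = time_list[0]                                                   # 80.0 ms
prefill_tput = 1000 * batch_size * seq_len / time_list[0]                     # 300.0 tok/s
decode_itl = sum(time_list[1:]) / (len(time_list) - 1)                        # 20.0 ms per step
decode_tput = 1000 * batch_size * (len(time_list) - 1) / sum(time_list[1:])   # 200.0 tok/s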