Commit 9e434492 authored by PanZezhong's avatar PanZezhong
Browse files

issue/140 准确率脚本支持torch,添加总吞吐 (accuracy script supports torch; add total-throughput reporting)

parent 7862a723
......@@ -251,6 +251,7 @@ void RankWorker::thread_loop() {
} else if (local_cmd == Command::RUN) {
try {
auto out = model_->forward(local_args);
infinicore::context::syncStream();
{
std::lock_guard<std::mutex> lk(mutex_);
......
......@@ -34,13 +34,6 @@ infinicore::Tensor LlamaForCausalLM::forward(const infinicore::Tensor &input_ids
// 2. Apply language modeling head to get logits
auto logits = lm_head_->forward(hidden_states);
// 3. CRITICAL: Synchronize the C++ backend's context after forward pass
// This ensures all C++ backend operations complete before returning to Python
if (device_.getType() != infinicore::Device::Type::CPU) {
infinicore::context::setDevice(device_, false);
infinicore::context::syncStream();
}
return logits;
}
......
......@@ -251,22 +251,31 @@ class GenerationMixin:
output_content += output_str
end_time = time.time()
time_list.append((end_time - start_time) * 1000)
time_list.append((end_time - start_time))
print(output_str, end="", flush=True)
if stop_on_eos and token_id in eos_token_id_list:
break
print("\n</s>")
print(f"\n\n\n Generation completed in {round(sum(time_list), 2)} ms")
print(f"\n\n\n Generation completed in {round(sum(time_list) * 1000, 2)} ms")
print(
f" Batchsize={batch_size} Per_Batch_Input_Len={seq_len} Per_Batch_New_Tokens={len(time_list)}\n"
)
print(
f" Prefill TTFT: {round(time_list[0], 2)}ms Throughput: {round((1000 * batch_size * seq_len) / time_list[0], 2)}tok/s\n",
f" Prefill TTFT: {round(time_list[0], 2)}ms Throughput: {round((batch_size * seq_len) / time_list[0], 2)}tok/s\n",
)
if len(time_list) > 1:
print(
f" Decode Avg ITL: {round(sum(time_list[1:]) / (len(time_list) - 1), 2)}ms Throughput: {round((1000 * batch_size * (len(time_list) - 1)) / sum(time_list[1:]), 2)}tok/s\n",
f" Decode Avg ITL: {round(sum(time_list[1:]) * 1000 / (len(time_list) - 1), 2)}ms Throughput: {round((batch_size * (len(time_list) - 1)) / sum(time_list[1:]), 2)}tok/s\n",
)
return {
"output_token_ids": output_tokens_list,
"output_content": output_content,
"total_latency": sum(time_list),
"prefill_latency": time_list[0],
"decode_latency": sum(time_list[1:]),
"total_input_tokens": batch_size * seq_len,
"total_output_tokens": len(time_list),
}
return output_tokens_list, output_content
This diff is collapsed.
Markdown is supported
Attach a file by dragging & dropping or selecting it.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment