Unverified Commit 87d9a261 authored by jiangkuaixue123's avatar jiangkuaixue123 Committed by GitHub
Browse files

[Bugfix] Fix ubatch wrapper num_tokens calculate (#33694)


Signed-off-by: default avatarjiangkuaixue123 <jiangxiaozhou111@163.com>
parent 80f921ba
......@@ -412,9 +412,7 @@ class UBatchWrapper:
attn_metadata = forward_context.attn_metadata
slot_mapping = forward_context.slot_mapping
num_tokens = (
ubatch_slices[0].token_slice.stop - ubatch_slices[0].token_slice.start
) * 2
num_tokens = sum(ubatch_slice.num_tokens for ubatch_slice in ubatch_slices)
input_ids = kwargs["input_ids"]
positions = kwargs["positions"]
intermediate_tensors = kwargs["intermediate_tensors"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment