Commit df49168d authored by zhuwenwen
Browse files

[fix] pp+mtp bs 1 correctness

parent f8cf43ae
...@@ -642,10 +642,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -642,10 +642,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if not is_last_rank: if not is_last_rank:
# Add new_token_ids to token_ids_cpu. # Add new_token_ids to token_ids_cpu.
start_token_index = num_computed_tokens start_token_index = num_computed_tokens
end_token_index = num_computed_tokens + len(new_token_ids) end_token_index = num_computed_tokens + 1
self.input_batch.token_ids_cpu[ self.input_batch.token_ids_cpu[
req_index, req_index,
start_token_index:end_token_index] = new_token_ids start_token_index:end_token_index] = new_token_ids[-1]
self.input_batch.num_tokens_no_spec[ self.input_batch.num_tokens_no_spec[
req_index] = end_token_index req_index] = end_token_index
self.input_batch.num_tokens[req_index] = end_token_index self.input_batch.num_tokens[req_index] = end_token_index
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment