Commit e9cfa85e authored by jujl1
Browse files

fix: update_state — optimize performance and remove redundant operations (优化性能,去除冗余操作)

parent be41974c
...@@ -513,7 +513,9 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -513,7 +513,9 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# Add the sampled token(s) from the previous step (if any). # Add the sampled token(s) from the previous step (if any).
# This doesn't include "unverified" tokens like spec tokens. # This doesn't include "unverified" tokens like spec tokens.
num_new_tokens = len(new_token_ids) num_new_tokens = len(new_token_ids)
if num_new_tokens > 0: if num_new_tokens == 1:
req_state.output_token_ids.append(new_token_ids[-1])
elif num_new_tokens > 0:
req_state.output_token_ids.extend( req_state.output_token_ids.extend(
new_token_ids) new_token_ids)
if len(spec_token_ids) > 0: if len(spec_token_ids) > 0:
...@@ -535,11 +537,6 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -535,11 +537,6 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# The request is not in the persistent batch. # The request is not in the persistent batch.
# The request was either preempted and resumed later, or was not # The request was either preempted and resumed later, or was not
# scheduled in the previous step and needs to be added again. # scheduled in the previous step and needs to be added again.
if not is_last_rank:
req_state = self.requests[req_id]
self.input_batch.add_request(req_state)
req_index = self.input_batch.req_id_to_index.get(req_id)
else:
req_ids_to_add.append(req_id) req_ids_to_add.append(req_id)
continue continue
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment