Commit 09deb20d (unverified), authored by Lianmin Zheng, committed by GitHub
Browse files

Optimize the memory usage of logits processor (#420)

parent 33b242df
...@@ -98,7 +98,9 @@ class LogitsProcessor(nn.Module): ...@@ -98,7 +98,9 @@ class LogitsProcessor(nn.Module):
all_logits = tensor_model_parallel_all_gather(all_logits) all_logits = tensor_model_parallel_all_gather(all_logits)
all_logits = all_logits[:, : self.config.vocab_size] all_logits = all_logits[:, : self.config.vocab_size]
all_logprobs = torch.log(torch.softmax(all_logits.float(), dim=-1) + 1e-6) all_logprobs = all_logits.float()
all_logits = None
all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs( prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
all_logprobs, input_metadata all_logprobs, input_metadata
......
...@@ -589,7 +589,7 @@ class ModelRpcServer: ...@@ -589,7 +589,7 @@ class ModelRpcServer:
+ len(req.output_ids) + len(req.output_ids)
- req.prompt_tokens, - req.prompt_tokens,
"completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward, "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
"finish_reason": req.finish_reason, "finish_reason": str(req.finish_reason),
"hit_stop_str": req.hit_stop_str, "hit_stop_str": req.hit_stop_str,
} }
if req.return_logprob: if req.return_logprob:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment