Unverified Commit 84377e5d authored by Tanmay Verma's avatar Tanmay Verma Committed by GitHub
Browse files

fix: Fix the protocol in the example (#1146)

parent 35229c74
...@@ -103,6 +103,27 @@ class TRTLLMWorkerResponseOutput: ...@@ -103,6 +103,27 @@ class TRTLLMWorkerResponseOutput:
# the result of result_handler passed to postprocess workers # the result of result_handler passed to postprocess workers
_postprocess_result: Any = None _postprocess_result: Any = None
@property
def length(self) -> int:
return 0 if self.token_ids is None else len(self.token_ids)
@property
def text_diff(self) -> str:
return self.text[self._last_text_len :]
@property
def token_ids_diff(self) -> List[int]:
return (
[] if self.token_ids is None else self.token_ids[self._last_token_ids_len :]
)
# Ignoring the mypy error here as this is copied from TensorRT-LLM project.
# https://github.com/NVIDIA/TensorRT-LLM/blob/19c6e68bec891b66146a09647ee7b70230ef5f67/tensorrt_llm/executor/result.py#L68
# TODO: Work with the TensorRT-LLM team to get this fixed.
@property
def logprobs_diff(self) -> List[float]: # type: ignore
return [] if self.logprobs is None else self.logprobs[self._last_logprobs_len :] # type: ignore
class TRTLLMWorkerResponse(BaseModel): class TRTLLMWorkerResponse(BaseModel):
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment