Unverified commit e84e0735, authored by Andrew Sansom, committed by GitHub
Browse files

fix: revert cast to cpu in `MsgpackEncoder._encode_tensor` to avoid hidden...


fix: revert cast to cpu in `MsgpackEncoder._encode_tensor` to avoid hidden performance regressions (#25738)
Signed-off-by: Andrew Sansom <andrew@protopia.ai>
parent 3edf87d2
...@@ -278,6 +278,11 @@ class InputPreprocessor: ...@@ -278,6 +278,11 @@ class InputPreprocessor:
raise ValueError( raise ValueError(
"prompt_embeds must be of shape (seq_len, hidden_size).") "prompt_embeds must be of shape (seq_len, hidden_size).")
# Tensors must be on CPU for serialization between processes
# in the MsgpackEncoder. Casting to CPU here ensures that there is no
# hidden device transfer in the critical path of generation.
prompt_embeds = prompt_embeds.cpu()
return embeds_inputs(prompt_embeds=prompt_embeds, return embeds_inputs(prompt_embeds=prompt_embeds,
cache_salt=parsed_content.get("cache_salt")) cache_salt=parsed_content.get("cache_salt"))
......
@@ -208,7 +208,7 @@ class MsgpackEncoder:
 ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
 assert self.aux_buffers is not None
 # view the tensor as a contiguous 1D array of bytes
-arr = obj.flatten().contiguous().cpu().view(torch.uint8).numpy()
+arr = obj.flatten().contiguous().view(torch.uint8).numpy()
 if obj.nbytes < self.size_threshold:
 # Smaller tensors are encoded inline, just like ndarrays.
 data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment