Unverified Commit e37ff5b5 authored by Wentao Ye's avatar Wentao Ye Committed by GitHub
Browse files

[Perf] Optimize token_embed for pooling models, 1.0% token throughput improvement (#37347)


Signed-off-by: default avataryewentao256 <zhyanwentao@126.com>
parent 6accb21f
...@@ -101,6 +101,7 @@ class PoolingMetadata: ...@@ -101,6 +101,7 @@ class PoolingMetadata:
num_scheduled_tokens_np: np.ndarray, num_scheduled_tokens_np: np.ndarray,
seq_lens_cpu: torch.Tensor, seq_lens_cpu: torch.Tensor,
device: torch.device, device: torch.device,
query_start_loc_gpu: torch.Tensor | None = None,
): ):
n_seq = len(num_scheduled_tokens_np) n_seq = len(num_scheduled_tokens_np)
prompt_lens = self.prompt_lens prompt_lens = self.prompt_lens
...@@ -109,11 +110,25 @@ class PoolingMetadata: ...@@ -109,11 +110,25 @@ class PoolingMetadata:
index = list(range(n_seq)) index = list(range(n_seq))
num_scheduled_tokens_cpu = torch.from_numpy(num_scheduled_tokens_np) num_scheduled_tokens_cpu = torch.from_numpy(num_scheduled_tokens_np)
cumsum = torch.zeros( if query_start_loc_gpu is None:
n_seq + 1, dtype=torch.int64, pin_memory=pin_memory, device="cpu" cumsum = torch.zeros(
) n_seq + 1, dtype=torch.int64, pin_memory=pin_memory, device="cpu"
torch.cumsum(num_scheduled_tokens_cpu, dim=0, out=cumsum[1:]) )
cumsum = cumsum.to(device, non_blocking=True) torch.cumsum(num_scheduled_tokens_cpu, dim=0, out=cumsum[1:])
cumsum = cumsum.to(device, non_blocking=True)
else:
if query_start_loc_gpu.shape[0] != n_seq + 1:
raise ValueError(
"query_start_loc_gpu length does not match "
f"the number of sequences: {query_start_loc_gpu.shape[0]} "
f"!= {n_seq + 1}."
)
if query_start_loc_gpu.device != device:
raise ValueError(
"query_start_loc_gpu must be on the same device as the "
f"hidden states: {query_start_loc_gpu.device} != {device}."
)
cumsum = query_start_loc_gpu
self.pooling_cursor = PoolingCursor( self.pooling_cursor = PoolingCursor(
index=index, index=index,
first_token_indices_gpu=cumsum[:n_seq], first_token_indices_gpu=cumsum[:n_seq],
......
...@@ -2928,7 +2928,10 @@ class GPUModelRunner( ...@@ -2928,7 +2928,10 @@ class GPUModelRunner(
pooling_metadata = self.input_batch.get_pooling_metadata() pooling_metadata = self.input_batch.get_pooling_metadata()
pooling_metadata.build_pooling_cursor( pooling_metadata.build_pooling_cursor(
num_scheduled_tokens_np, seq_lens_cpu, device=hidden_states.device num_scheduled_tokens_np,
seq_lens_cpu,
device=hidden_states.device,
query_start_loc_gpu=self.query_start_loc.gpu[: num_reqs + 1],
) )
model = cast(VllmModelForPooling, self.model) model = cast(VllmModelForPooling, self.model)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment