Commit 31653dd9 authored by lizhigong
Browse files

Merge branch 'v0.5.4_dev_niuhb' into 'v0.5.4_dev'

add pin_memory

See merge request OpenDAS/sglang!35
parents 875344ee 06b29699
......@@ -369,7 +369,7 @@ class ForwardBatch:
if batch.extend_input_logprob_token_ids is not None:
ret.extend_input_logprob_token_ids_gpu = (
-batch.extend_input_logprob_token_ids.to(device, non_blocking=True)
+batch.extend_input_logprob_token_ids.pin_memory().to(device, non_blocking=True)
)
if enable_num_token_non_padded(model_runner.server_args):
......@@ -425,10 +425,10 @@ class ForwardBatch:
assert isinstance(batch.extend_prefix_lens, list)
ret.extend_seq_lens = torch.tensor(
batch.extend_seq_lens, dtype=torch.int32
-).to(device, non_blocking=True)
+).pin_memory().to(device, non_blocking=True)
ret.extend_prefix_lens = torch.tensor(
batch.extend_prefix_lens, dtype=torch.int32
-).to(device, non_blocking=True)
+).pin_memory().to(device, non_blocking=True)
ret.extend_num_tokens = batch.extend_num_tokens
positions, ret.extend_start_loc = compute_position(
model_runner.server_args.attention_backend,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment