"torchvision/vscode:/vscode.git/clone" did not exist on "bb88c4520b835e79d5d3c4423eb7ff7c26fa2043"
Commit 06b29699 authored by shangxl's avatar shangxl
Browse files

add pin_memory

parent 5533c538
...@@ -369,7 +369,7 @@ class ForwardBatch: ...@@ -369,7 +369,7 @@ class ForwardBatch:
if batch.extend_input_logprob_token_ids is not None: if batch.extend_input_logprob_token_ids is not None:
ret.extend_input_logprob_token_ids_gpu = ( ret.extend_input_logprob_token_ids_gpu = (
batch.extend_input_logprob_token_ids.to(device, non_blocking=True) batch.extend_input_logprob_token_ids.pin_memory().to(device, non_blocking=True)
) )
if enable_num_token_non_padded(model_runner.server_args): if enable_num_token_non_padded(model_runner.server_args):
...@@ -425,10 +425,10 @@ class ForwardBatch: ...@@ -425,10 +425,10 @@ class ForwardBatch:
assert isinstance(batch.extend_prefix_lens, list) assert isinstance(batch.extend_prefix_lens, list)
ret.extend_seq_lens = torch.tensor( ret.extend_seq_lens = torch.tensor(
batch.extend_seq_lens, dtype=torch.int32 batch.extend_seq_lens, dtype=torch.int32
).to(device, non_blocking=True) ).pin_memory().to(device, non_blocking=True)
ret.extend_prefix_lens = torch.tensor( ret.extend_prefix_lens = torch.tensor(
batch.extend_prefix_lens, dtype=torch.int32 batch.extend_prefix_lens, dtype=torch.int32
).to(device, non_blocking=True) ).pin_memory().to(device, non_blocking=True)
ret.extend_num_tokens = batch.extend_num_tokens ret.extend_num_tokens = batch.extend_num_tokens
positions, ret.extend_start_loc = compute_position( positions, ret.extend_start_loc = compute_position(
model_runner.server_args.attention_backend, model_runner.server_args.attention_backend,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment