Commit a31cab75 (unverified), authored by Antoni Baum, committed by GitHub
Browse files

[Core] Avoid copying prompt/output tokens if no penalties are used (#5289)

parent 828da0d4
...@@ -386,6 +386,7 @@ class SamplingTensors: ...@@ -386,6 +386,7 @@ class SamplingTensors:
presence_penalties += [0] * prefill_len presence_penalties += [0] * prefill_len
frequency_penalties += [0] * prefill_len frequency_penalties += [0] * prefill_len
repetition_penalties += [1] * prefill_len repetition_penalties += [1] * prefill_len
if do_penalties:
prompt_tokens.extend([] for _ in range(prefill_len)) prompt_tokens.extend([] for _ in range(prefill_len))
output_tokens.extend([] for _ in range(prefill_len)) output_tokens.extend([] for _ in range(prefill_len))
...@@ -394,6 +395,7 @@ class SamplingTensors: ...@@ -394,6 +395,7 @@ class SamplingTensors:
assert sample_lens == len(seq_ids) assert sample_lens == len(seq_ids)
for seq_id in seq_ids: for seq_id in seq_ids:
seq_data = seq_group.seq_data[seq_id] seq_data = seq_group.seq_data[seq_id]
if do_penalties:
prompt_tokens.append(seq_data.prompt_token_ids) prompt_tokens.append(seq_data.prompt_token_ids)
output_tokens.append(seq_data.output_token_ids) output_tokens.append(seq_data.output_token_ids)
temperatures += [temperature] * len(seq_ids) temperatures += [temperature] * len(seq_ids)
...@@ -443,6 +445,10 @@ class SamplingTensors: ...@@ -443,6 +445,10 @@ class SamplingTensors:
# Note that the performance will be very bad without # Note that the performance will be very bad without
# pinned memory. # pinned memory.
pin_memory = is_pin_memory_available() pin_memory = is_pin_memory_available()
do_penalties = prompt_tokens or output_tokens
if do_penalties:
prompt_max_len = max([len(tokens) for tokens in prompt_tokens], prompt_max_len = max([len(tokens) for tokens in prompt_tokens],
default=0) default=0)
prompt_padded_tokens = [ prompt_padded_tokens = [
...@@ -504,6 +510,7 @@ class SamplingTensors: ...@@ -504,6 +510,7 @@ class SamplingTensors:
dtype=torch.long, dtype=torch.long,
pin_memory=pin_memory, pin_memory=pin_memory,
) )
if do_penalties:
prompt_tensor = torch.tensor( prompt_tensor = torch.tensor(
prompt_padded_tokens, prompt_padded_tokens,
device="cpu", device="cpu",
...@@ -516,6 +523,9 @@ class SamplingTensors: ...@@ -516,6 +523,9 @@ class SamplingTensors:
dtype=torch.long, dtype=torch.long,
pin_memory=pin_memory, pin_memory=pin_memory,
) )
else:
prompt_tensor = None
output_tensor = None
# need to transpose and make contiguous to # need to transpose and make contiguous to
# copy the tensor correctly. # copy the tensor correctly.
# [batch_size, n_seeds] -> [n_seeds, batch_size] # [batch_size, n_seeds] -> [n_seeds, batch_size]
...@@ -538,6 +548,16 @@ class SamplingTensors: ...@@ -538,6 +548,16 @@ class SamplingTensors:
extra_seeds_gpu = None extra_seeds_gpu = None
sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds] sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds]
if do_penalties:
prompt_tokens_gpu = prompt_tensor.to(device=device,
non_blocking=True)
output_tokens_gpu = output_tensor.to(device=device,
non_blocking=True)
else:
empty_tensor = torch.empty(0, device=device, dtype=torch.long)
prompt_tokens_gpu = empty_tensor
output_tokens_gpu = empty_tensor
return cls( return cls(
temperatures=temperatures_t.to(device=device, non_blocking=True), temperatures=temperatures_t.to(device=device, non_blocking=True),
top_ps=top_ps_t.to(device=device, non_blocking=True), top_ps=top_ps_t.to(device=device, non_blocking=True),
...@@ -549,8 +569,8 @@ class SamplingTensors: ...@@ -549,8 +569,8 @@ class SamplingTensors:
non_blocking=True), non_blocking=True),
repetition_penalties=repetition_penalties_t.to(device=device, repetition_penalties=repetition_penalties_t.to(device=device,
non_blocking=True), non_blocking=True),
prompt_tokens=prompt_tensor.to(device=device, non_blocking=True), prompt_tokens=prompt_tokens_gpu,
output_tokens=output_tensor.to(device=device, non_blocking=True), output_tokens=output_tokens_gpu,
sampling_seeds=sampling_seeds_gpu, sampling_seeds=sampling_seeds_gpu,
sample_indices=sample_indices_t.to(device=device, sample_indices=sample_indices_t.to(device=device,
non_blocking=True), non_blocking=True),
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment