Unverified Commit 87c94bc8 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Revert "Update sampling_metadata.py (#21937)" (#22088)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 28b18cc7
...@@ -539,37 +539,37 @@ class SamplingTensors: ...@@ -539,37 +539,37 @@ class SamplingTensors:
temperatures_t = torch.tensor( temperatures_t = torch.tensor(
temperatures, temperatures,
device="cpu", device="cpu",
dtype=torch.float32, dtype=dtype,
pin_memory=pin_memory, pin_memory=pin_memory,
) )
top_ps_t = torch.tensor( top_ps_t = torch.tensor(
top_ps, top_ps,
device="cpu", device="cpu",
dtype=torch.float32, dtype=dtype,
pin_memory=pin_memory, pin_memory=pin_memory,
) )
min_ps_t = torch.tensor( min_ps_t = torch.tensor(
min_ps, min_ps,
device="cpu", device="cpu",
dtype=torch.float32, dtype=dtype,
pin_memory=pin_memory, pin_memory=pin_memory,
) )
presence_penalties_t = torch.tensor( presence_penalties_t = torch.tensor(
presence_penalties, presence_penalties,
device="cpu", device="cpu",
dtype=torch.float32, dtype=dtype,
pin_memory=pin_memory, pin_memory=pin_memory,
) )
frequency_penalties_t = torch.tensor( frequency_penalties_t = torch.tensor(
frequency_penalties, frequency_penalties,
device="cpu", device="cpu",
dtype=torch.float32, dtype=dtype,
pin_memory=pin_memory, pin_memory=pin_memory,
) )
repetition_penalties_t = torch.tensor( repetition_penalties_t = torch.tensor(
repetition_penalties, repetition_penalties,
device="cpu", device="cpu",
dtype=torch.float32, dtype=dtype,
pin_memory=pin_memory, pin_memory=pin_memory,
) )
top_ks_t = torch.tensor( top_ks_t = torch.tensor(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment