Commit 0e4f5361 authored by haileyschoelkopf, committed by lintangsutawika
Browse files

init tensors on-device where possible

parent bc7f52e6
......@@ -379,14 +379,16 @@ class HFLM(LM):
inp = torch.tensor(
(context_enc)[-self.max_length :],
dtype=torch.long,
).to(self.device)
device=self.device
)
(inplen,) = inp.shape
cont = torch.tensor(
(continuation_enc)[-self.max_length :],
# TODO: left-shift these?
# TODO: our code assumes we never end up truncating conts for either model type
dtype=torch.long,
).to(self.device)
device=self.device,
)
(contlen,) = cont.shape
conts.append(cont)
......
......@@ -448,8 +448,10 @@ def pad_and_concat(max_length:int, tensors: List[torch.Tensor], padding_side="ri
tensors[i] = torch.cat(
[
tensor, # [seq]
torch.zeros(max_length - tensor_len, dtype=torch.long).to(
tensor.device
torch.zeros(
max_length - tensor_len,
dtype=torch.long,
device=tensor.device,
), # [padding_length - seq]
],
dim=0,
......@@ -458,8 +460,10 @@ def pad_and_concat(max_length:int, tensors: List[torch.Tensor], padding_side="ri
# left-pad
tensors[i] = torch.cat(
[
torch.zeros(max_length - tensor_len, dtype=torch.long).to(
tensor.device
torch.zeros(
max_length - tensor_len,
dtype=torch.long,
device=tensor.device,
), # [padding_length - seq]
tensor, # [seq]
],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment