Commit 0e4f5361 authored by haileyschoelkopf's avatar haileyschoelkopf Committed by lintangsutawika
Browse files

init tensors on-device where possible

parent bc7f52e6
...@@ -379,14 +379,16 @@ class HFLM(LM): ...@@ -379,14 +379,16 @@ class HFLM(LM):
inp = torch.tensor( inp = torch.tensor(
(context_enc)[-self.max_length :], (context_enc)[-self.max_length :],
dtype=torch.long, dtype=torch.long,
).to(self.device) device=self.device
)
(inplen,) = inp.shape (inplen,) = inp.shape
cont = torch.tensor( cont = torch.tensor(
(continuation_enc)[-self.max_length :], (continuation_enc)[-self.max_length :],
# TODO: left-shift these? # TODO: left-shift these?
# TODO: our code assumes we never end up truncating conts for either model type # TODO: our code assumes we never end up truncating conts for either model type
dtype=torch.long, dtype=torch.long,
).to(self.device) device=self.device,
)
(contlen,) = cont.shape (contlen,) = cont.shape
conts.append(cont) conts.append(cont)
......
...@@ -448,8 +448,10 @@ def pad_and_concat(max_length:int, tensors: List[torch.Tensor], padding_side="ri ...@@ -448,8 +448,10 @@ def pad_and_concat(max_length:int, tensors: List[torch.Tensor], padding_side="ri
tensors[i] = torch.cat( tensors[i] = torch.cat(
[ [
tensor, # [seq] tensor, # [seq]
torch.zeros(max_length - tensor_len, dtype=torch.long).to( torch.zeros(
tensor.device max_length - tensor_len,
dtype=torch.long,
device=tensor.device,
), # [padding_length - seq] ), # [padding_length - seq]
], ],
dim=0, dim=0,
...@@ -458,8 +460,10 @@ def pad_and_concat(max_length:int, tensors: List[torch.Tensor], padding_side="ri ...@@ -458,8 +460,10 @@ def pad_and_concat(max_length:int, tensors: List[torch.Tensor], padding_side="ri
# left-pad # left-pad
tensors[i] = torch.cat( tensors[i] = torch.cat(
[ [
torch.zeros(max_length - tensor_len, dtype=torch.long).to( torch.zeros(
tensor.device max_length - tensor_len,
dtype=torch.long,
device=tensor.device,
), # [padding_length - seq] ), # [padding_length - seq]
tensor, # [seq] tensor, # [seq]
], ],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment