Unverified Commit 686cc667 authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

fix(server): Check for device type correctly when determining initial padding (#16)

AFAIK there is no torch device type called "gpu".
parent 611e21cb
...@@ -65,7 +65,7 @@ class CausalLMBatch: ...@@ -65,7 +65,7 @@ class CausalLMBatch:
) )
all_logprobs.append(None) all_logprobs.append(None)
pad_to_multiple_of = 8 if "gpu" in str(device) else None pad_to_multiple_of = 8 if device.type == "cuda" else None
tokenized_inputs = tokenizer( tokenized_inputs = tokenizer(
inputs, inputs,
return_tensors="pt", return_tensors="pt",
......
...@@ -77,7 +77,7 @@ class Seq2SeqLMBatch: ...@@ -77,7 +77,7 @@ class Seq2SeqLMBatch:
decoder_logprobs.append(None) decoder_logprobs.append(None)
# Tokenize batch # Tokenize batch
pad_to_multiple_of = 8 if "gpu" in str(device) else None pad_to_multiple_of = 8 if device.type == "cuda" else None
tokenized_inputs = tokenizer( tokenized_inputs = tokenizer(
inputs, inputs,
return_tensors="pt", return_tensors="pt",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment