Commit 042180d8 authored by OlivierDehaene's avatar OlivierDehaene
Browse files

fix(server): Only pad to multiple of 8 on GPUs

parent a2985036
@@ -71,8 +71,9 @@ class CausalLMBatch:
             )
         )
+        pad_to_multiple_of = 8 if "gpu" in str(device) else None
         tokenized_inputs = tokenizer(
-            inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
+            inputs, return_tensors="pt", padding=True, pad_to_multiple_of=pad_to_multiple_of
         ).to(device)
         all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
...
@@ -83,8 +83,9 @@ class Seq2SeqLMBatch:
             )
         # Tokenize batch
+        pad_to_multiple_of = 8 if "gpu" in str(device) else None
         tokenized_inputs = tokenizer(
-            inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
+            inputs, return_tensors="pt", padding=True, pad_to_multiple_of=pad_to_multiple_of
         ).to(device)
         # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
         decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment