Commit 1f4b27c6 authored by Geewook Kim's avatar Geewook Kim
Browse files

feat: remove bfloat16 for cpu

parent 6f8a40da
...@@ -52,8 +52,6 @@ if __name__ == "__main__": ...@@ -52,8 +52,6 @@ if __name__ == "__main__":
pretrained_model.half() pretrained_model.half()
device = torch.device("cuda") device = torch.device("cuda")
pretrained_model.to(device) pretrained_model.to(device)
else:
pretrained_model.encoder.to(torch.bfloat16)
pretrained_model.eval() pretrained_model.eval()
......
...@@ -443,8 +443,6 @@ class DonutModel(PreTrainedModel): ...@@ -443,8 +443,6 @@ class DonutModel(PreTrainedModel):
if self.device.type == "cuda": # half is not compatible in cpu implementation. if self.device.type == "cuda": # half is not compatible in cpu implementation.
image_tensors = image_tensors.half() image_tensors = image_tensors.half()
image_tensors = image_tensors.to(self.device) image_tensors = image_tensors.to(self.device)
else:
image_tensors = image_tensors.to(torch.bfloat16)
if prompt_tensors is None: if prompt_tensors is None:
prompt_tensors = self.decoder.tokenizer(prompt, add_special_tokens=False, return_tensors="pt")["input_ids"] prompt_tensors = self.decoder.tokenizer(prompt, add_special_tokens=False, return_tensors="pt")["input_ids"]
......
...@@ -24,8 +24,6 @@ def test(args): ...@@ -24,8 +24,6 @@ def test(args):
if torch.cuda.is_available(): if torch.cuda.is_available():
pretrained_model.half() pretrained_model.half()
pretrained_model.to("cuda") pretrained_model.to("cuda")
else:
pretrained_model.encoder.to(torch.bfloat16)
pretrained_model.eval() pretrained_model.eval()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment