Unverified Commit 4f4c9c16 authored by OlivierDehaene's avatar OlivierDehaene Committed by GitHub
Browse files

fix(server): t5 cannot run in f16 (#356)

Fix #349
parent 91d9beec
...@@ -40,7 +40,7 @@ class T5Sharded(Seq2SeqLM): ...@@ -40,7 +40,7 @@ class T5Sharded(Seq2SeqLM):
self.process_group, rank, world_size = initialize_torch_distributed() self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.float16 dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
else: else:
device = torch.device("cpu") device = torch.device("cpu")
dtype = torch.float32 dtype = torch.float32
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment