Unverified Commit 745f596c authored by OlivierDehaene, committed by GitHub

feat(server): use float16 (#304)

parent 68e9d6ab
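
Every hunk below makes the same one-line change: on CUDA the server previously chose torch.bfloat16 when the GPU supported it and fell back to torch.float32 otherwise; it now always uses torch.float16. CPU execution stays in torch.float32. A minimal sketch of the resulting selection logic, factored into a helper for illustration (the get_dtype_and_device name is hypothetical; the repository inlines this logic in each model's __init__):

import torch

def get_dtype_and_device(rank: int = 0):
    # Post-commit behaviour: always float16 on CUDA, float32 on CPU.
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{rank}")
        dtype = torch.float16
    else:
        device = torch.device("cpu")
        dtype = torch.float32
    return dtype, device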
@@ -67,7 +67,7 @@ class BLOOMSharded(BLOOM):
         self.master = rank == 0
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+            dtype = torch.float16
         else:
             device = torch.device("cpu")
             dtype = torch.float32
@@ -452,7 +452,7 @@ class CausalLM(Model):
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+            dtype = torch.float16
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")
@@ -199,7 +199,7 @@ class GalacticaSharded(Galactica):
         self.master = rank == 0
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+            dtype = torch.float16
         else:
             device = torch.device("cpu")
             dtype = torch.float32
@@ -38,7 +38,7 @@ class GPTNeoxSharded(CausalLM):
         self.master = rank == 0
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+            dtype = torch.float16
         else:
             device = torch.device("cpu")
             dtype = torch.float32
@@ -54,7 +54,7 @@ class OPTSharded(OPT):
         self.master = rank == 0
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+            dtype = torch.float16
         else:
             device = torch.device("cpu")
             dtype = torch.float32
@@ -17,7 +17,7 @@ class SantaCoder(CausalLM):
     def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+            dtype = torch.float16
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")
@@ -506,7 +506,7 @@ class Seq2SeqLM(Model):
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+            dtype = torch.float16
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")
@@ -38,7 +38,7 @@ class T5Sharded(Seq2SeqLM):
         self.master = rank == 0
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+            dtype = torch.float16
         else:
             device = torch.device("cpu")
             dtype = torch.float32
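
For context on the tradeoff: float16 and bfloat16 are both 16-bit formats, but float16 offers finer precision while bfloat16 keeps roughly the dynamic range of float32 (max ~3.4e38, versus ~65504 for float16). The snippet below, illustrative rather than part of this commit, compares the two with torch.finfo:

import torch

# Compare the numeric properties of the two 16-bit floating-point formats.
for dtype in (torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    print(dtype, "max:", info.max, "eps:", info.eps)
# torch.float16  max: 65504.0   eps: ~9.77e-04
# torch.bfloat16 max: ~3.39e+38 eps: ~7.81e-03

Since float16 can overflow where bfloat16 would not, this switch trades dynamic range for float16's smaller epsilon.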