Unverified commit 745f596c, authored by OlivierDehaene and committed by GitHub
Browse files

feat(server): use float16 (#304)

parent 68e9d6ab
...@@ -67,7 +67,7 @@ class BLOOMSharded(BLOOM): ...@@ -67,7 +67,7 @@ class BLOOMSharded(BLOOM):
self.master = rank == 0 self.master = rank == 0
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 dtype = torch.float16
else: else:
device = torch.device("cpu") device = torch.device("cpu")
dtype = torch.float32 dtype = torch.float32
......
...@@ -452,7 +452,7 @@ class CausalLM(Model): ...@@ -452,7 +452,7 @@ class CausalLM(Model):
): ):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda") device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 dtype = torch.float16
else: else:
if quantize: if quantize:
raise ValueError("quantization is not available on CPU") raise ValueError("quantization is not available on CPU")
......
...@@ -199,7 +199,7 @@ class GalacticaSharded(Galactica): ...@@ -199,7 +199,7 @@ class GalacticaSharded(Galactica):
self.master = rank == 0 self.master = rank == 0
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 dtype = torch.float16
else: else:
device = torch.device("cpu") device = torch.device("cpu")
dtype = torch.float32 dtype = torch.float32
......
...@@ -38,7 +38,7 @@ class GPTNeoxSharded(CausalLM): ...@@ -38,7 +38,7 @@ class GPTNeoxSharded(CausalLM):
self.master = rank == 0 self.master = rank == 0
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 dtype = torch.float16
else: else:
device = torch.device("cpu") device = torch.device("cpu")
dtype = torch.float32 dtype = torch.float32
......
...@@ -54,7 +54,7 @@ class OPTSharded(OPT): ...@@ -54,7 +54,7 @@ class OPTSharded(OPT):
self.master = rank == 0 self.master = rank == 0
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 dtype = torch.float16
else: else:
device = torch.device("cpu") device = torch.device("cpu")
dtype = torch.float32 dtype = torch.float32
......
...@@ -17,7 +17,7 @@ class SantaCoder(CausalLM): ...@@ -17,7 +17,7 @@ class SantaCoder(CausalLM):
def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False): def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda") device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 dtype = torch.float16
else: else:
if quantize: if quantize:
raise ValueError("quantization is not available on CPU") raise ValueError("quantization is not available on CPU")
......
...@@ -506,7 +506,7 @@ class Seq2SeqLM(Model): ...@@ -506,7 +506,7 @@ class Seq2SeqLM(Model):
): ):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda") device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 dtype = torch.float16
else: else:
if quantize: if quantize:
raise ValueError("quantization is not available on CPU") raise ValueError("quantization is not available on CPU")
......
...@@ -38,7 +38,7 @@ class T5Sharded(Seq2SeqLM): ...@@ -38,7 +38,7 @@ class T5Sharded(Seq2SeqLM):
self.master = rank == 0 self.master = rank == 0
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 dtype = torch.float16
else: else:
device = torch.device("cpu") device = torch.device("cpu")
dtype = torch.float32 dtype = torch.float32
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment