OpenDAS / text-generation-inference / Commits / 745f596c
Commit 745f596c (unverified), authored May 10, 2023 by OlivierDehaene, committed by GitHub on May 10, 2023

feat(server): use float16 (#304)

Parent: 68e9d6ab
Showing 8 changed files with 8 additions and 8 deletions.
server/text_generation_server/models/bloom.py (+1 −1)
server/text_generation_server/models/causal_lm.py (+1 −1)
server/text_generation_server/models/galactica.py (+1 −1)
server/text_generation_server/models/gpt_neox.py (+1 −1)
server/text_generation_server/models/opt.py (+1 −1)
server/text_generation_server/models/santacoder.py (+1 −1)
server/text_generation_server/models/seq2seq_lm.py (+1 −1)
server/text_generation_server/models/t5.py (+1 −1)
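All eight files make the same one-line change: on CUDA devices the server now always loads weights in torch.float16, instead of preferring torch.bfloat16 when the GPU supports it. A minimal sketch of the selection logic before and after this commit (condensed from the diffs below; the helper names are illustrative, not from the repository):

    import torch

    def select_dtype_before() -> torch.dtype:
        # Pre-commit: prefer bfloat16 on GPUs that support it (Ampere and newer),
        # otherwise fall back to float32.
        if torch.cuda.is_available():
            return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        return torch.float32

    def select_dtype_after() -> torch.dtype:
        # Post-commit: always float16 on GPU, float32 on CPU.
        if torch.cuda.is_available():
            return torch.float16
        return torch.float32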
server/text_generation_server/models/bloom.py @ 745f596c
@@ -67,7 +67,7 @@ class BLOOMSharded(BLOOM):
         self.master = rank == 0
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+            dtype = torch.float16
         else:
             device = torch.device("cpu")
             dtype = torch.float32
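The tradeoff behind the change: float16 has 10 mantissa bits to bfloat16's 7, so it resolves more significant digits, while bfloat16 keeps float32's exponent range and is far harder to overflow. The two formats can be inspected directly in PyTorch (illustrative snippet, not part of the commit):

    import torch

    # float16: max ~65504, fine-grained eps; bfloat16: float32-like range, coarse eps.
    for dtype in (torch.float16, torch.bfloat16, torch.float32):
        info = torch.finfo(dtype)
        print(dtype, "max:", info.max, "eps:", info.eps)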
server/text_generation_server/models/causal_lm.py @ 745f596c
@@ -452,7 +452,7 @@ class CausalLM(Model):
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+            dtype = torch.float16
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")
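In each constructor the selected device and dtype are then used to load the model weights. A sketch of the general pattern, assuming a Hugging Face transformers checkpoint (the model id and surrounding code are illustrative, not lines from this commit):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    model_id = "bigscience/bloom-560m"  # illustrative checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype).to(device)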
server/text_generation_server/models/galactica.py @ 745f596c
@@ -199,7 +199,7 @@ class GalacticaSharded(Galactica):
         self.master = rank == 0
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+            dtype = torch.float16
         else:
             device = torch.device("cpu")
             dtype = torch.float32
server/text_generation_server/models/gpt_neox.py @ 745f596c
@@ -38,7 +38,7 @@ class GPTNeoxSharded(CausalLM):
         self.master = rank == 0
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+            dtype = torch.float16
         else:
             device = torch.device("cpu")
             dtype = torch.float32
server/text_generation_server/models/opt.py @ 745f596c
@@ -54,7 +54,7 @@ class OPTSharded(OPT):
         self.master = rank == 0
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+            dtype = torch.float16
         else:
             device = torch.device("cpu")
             dtype = torch.float32
server/text_generation_server/models/santacoder.py @ 745f596c
@@ -17,7 +17,7 @@ class SantaCoder(CausalLM):
     def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+            dtype = torch.float16
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")
server/text_generation_server/models/seq2seq_lm.py @ 745f596c
@@ -506,7 +506,7 @@ class Seq2SeqLM(Model):
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+            dtype = torch.float16
         else:
             if quantize:
                 raise ValueError("quantization is not available on CPU")
server/text_generation_server/models/t5.py @ 745f596c
@@ -38,7 +38,7 @@ class T5Sharded(Seq2SeqLM):
         self.master = rank == 0
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+            dtype = torch.float16
         else:
             device = torch.device("cpu")
             dtype = torch.float32
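After this commit the same device/dtype selection appears verbatim in all eight constructors. A sketch of how the repetition could be hoisted into one shared helper (hypothetical function, not part of this commit or the repository):

    import torch
    from typing import Tuple

    def get_device_and_dtype(quantize: bool = False) -> Tuple[torch.device, torch.dtype]:
        # Hypothetical consolidation of the per-model logic changed in this commit.
        if torch.cuda.is_available():
            return torch.device("cuda"), torch.float16
        if quantize:
            raise ValueError("quantization is not available on CPU")
        return torch.device("cpu"), torch.float32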