OpenDAS / text-generation-inference / Commits / ecf6dc3a

Unverified commit ecf6dc3a, authored Jun 30, 2023 by Nicolas Patry, committed by GitHub on Jun 30, 2023

feat: Add the option to force another dtype than `f16`. (#513)

parent 3b0c979e
Showing 17 changed files with 130 additions and 21 deletions (+130 -21)
launcher/src/main.rs  +32 -0
server/text_generation_server/cli.py  +14 -1
server/text_generation_server/models/__init__.py  +50 -5
server/text_generation_server/models/bloom.py  +2 -1
server/text_generation_server/models/causal_lm.py  +2 -1
server/text_generation_server/models/flash_llama.py  +2 -1
server/text_generation_server/models/flash_neox.py  +2 -1
server/text_generation_server/models/flash_rw.py  +2 -1
server/text_generation_server/models/flash_santacoder.py  +2 -1
server/text_generation_server/models/galactica.py  +2 -1
server/text_generation_server/models/gpt_neox.py  +2 -1
server/text_generation_server/models/opt.py  +2 -1
server/text_generation_server/models/rw.py  +2 -1
server/text_generation_server/models/santacoder.py  +2 -1
server/text_generation_server/models/seq2seq_lm.py  +2 -1
server/text_generation_server/models/t5.py  +2 -1
server/text_generation_server/server.py  +8 -2
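The user-facing result of this commit is a new `--dtype` option on the launcher (also settable from the environment, since the field is declared with `#[clap(long, env, value_enum)]`), accepting `float16` or `bfloat16` and rejected when combined with `--quantize`. A minimal launch sketch follows; the `text-generation-launcher` binary name and the model id are illustrative assumptions, not part of this diff:

    # Illustrative launch only: binary name and model id are placeholders, not taken from this diff.
    import subprocess

    subprocess.run(
        [
            "text-generation-launcher",          # assumed launcher binary name
            "--model-id", "bigscience/bloom-560m",
            "--dtype", "bfloat16",               # new option added by this commit
        ],
        check=True,
    )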
launcher/src/main.rs

@@ -36,6 +36,26 @@ impl std::fmt::Display for Quantization {
        }
    }
}

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Dtype {
    Float16,
    BFloat16,
}

impl std::fmt::Display for Dtype {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // To keep in track with `server`.
        match self {
            Dtype::Float16 => {
                write!(f, "float16")
            }
            Dtype::BFloat16 => {
                write!(f, "bfloat16")
            }
        }
    }
}

/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]

@@ -71,6 +91,10 @@ struct Args {
    #[clap(long, env, value_enum)]
    quantize: Option<Quantization>,

    /// The dtype to be forced upon the model. This option cannot be used with `--quantize`.
    #[clap(long, env, value_enum)]
    dtype: Option<Dtype>,

    /// Whether you want to execute hub modelling code. Explicitly passing a `revision` is
    /// encouraged when loading a model with custom code to ensure no malicious code has been
    /// contributed in a newer revision.

@@ -258,6 +282,7 @@ fn shard_manager(
    model_id: String,
    revision: Option<String>,
    quantize: Option<Quantization>,
    dtype: Option<Dtype>,
    trust_remote_code: bool,
    uds_path: String,
    rank: usize,

@@ -307,6 +332,11 @@ fn shard_manager(
        shard_argv.push(quantize.to_string())
    }

    if let Some(dtype) = dtype {
        shard_argv.push("--dtype".to_string());
        shard_argv.push(dtype.to_string())
    }

    // Model optional revision
    if let Some(revision) = revision {
        shard_argv.push("--revision".to_string());

@@ -743,6 +773,7 @@ fn spawn_shards(
        let shutdown_sender = shutdown_sender.clone();
        let otlp_endpoint = args.otlp_endpoint.clone();
        let quantize = args.quantize;
        let dtype = args.dtype;
        let trust_remote_code = args.trust_remote_code;
        let master_port = args.master_port;
        let disable_custom_kernels = args.disable_custom_kernels;

@@ -753,6 +784,7 @@ fn spawn_shards(
            model_id,
            revision,
            quantize,
            dtype,
            trust_remote_code,
            uds_path,
            rank,
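The new `if let Some(dtype)` block in `shard_manager` only appends two tokens to each shard's argument vector. A sketch of the same behaviour in Python terms (the argv prefix shown is a placeholder and is not asserted by this diff):

    # Sketch: when a dtype is set, the launcher forwards it to every shard as `--dtype <value>`.
    shard_argv = ["text-generation-server", "serve", "bigscience/bloom-560m"]  # placeholder prefix
    dtype = "bfloat16"
    if dtype is not None:
        shard_argv += ["--dtype", dtype]
    print(shard_argv)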
server/text_generation_server/cli.py

@@ -16,12 +16,18 @@ class Quantization(str, Enum):
    gptq = "gptq"


class Dtype(str, Enum):
    float16 = "float16"
    bloat16 = "bfloat16"


@app.command()
def serve(
    model_id: str,
    revision: Optional[str] = None,
    sharded: bool = False,
    quantize: Optional[Quantization] = None,
    dtype: Optional[Dtype] = None,
    trust_remote_code: bool = False,
    uds_path: Path = "/tmp/text-generation-server",
    logger_level: str = "INFO",

@@ -64,7 +70,14 @@ def serve(
    # Downgrade enum into str for easier management later on
    quantize = None if quantize is None else quantize.value
    dtype = None if dtype is None else dtype.value
    if dtype is not None and quantize is not None:
        raise RuntimeError(
            "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
        )
    server.serve(
        model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path
    )


@app.command()
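The CLI now refuses to combine `--dtype` with `--quantize` before anything is loaded, and downgrades both enums to plain strings. A standalone sketch of that validation, mirroring the code above (the helper name is illustrative, not an import from the package):

    from enum import Enum
    from typing import Optional

    class Dtype(str, Enum):
        float16 = "float16"
        bloat16 = "bfloat16"

    def resolve_cli_options(quantize: Optional[str], dtype: Optional[Dtype]):
        # Downgrade enum into str for easier management later on, as in `serve`
        dtype = None if dtype is None else dtype.value
        if dtype is not None and quantize is not None:
            raise RuntimeError(
                "Only 1 can be set between `dtype` and `quantize`, "
                "as they both decide how goes the final model."
            )
        return quantize, dtype

    # resolve_cli_options("gptq", Dtype.bloat16)   -> raises RuntimeError
    # resolve_cli_options(None, Dtype.float16)     -> (None, "float16")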
server/text_generation_server/models/__init__.py

@@ -100,11 +100,25 @@ def get_model(
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    dtype: Optional[str],
    trust_remote_code: bool,
) -> Model:
    if dtype is None:
        dtype = torch.float16
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

    if "facebook/galactica" in model_id:
        return GalacticaSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_id.startswith("bigcode/"):

@@ -113,6 +127,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    elif sharded:

@@ -124,6 +139,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

@@ -138,6 +154,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    elif sharded:

@@ -149,12 +166,17 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "bloom":
        return BLOOMSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    elif model_type == "gpt_neox":

@@ -163,6 +185,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    elif sharded:

@@ -170,6 +193,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    else:

@@ -177,6 +201,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

@@ -186,6 +211,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    elif sharded:

@@ -195,6 +221,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

@@ -210,6 +237,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    raise NotImplementedError(

@@ -221,6 +249,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    else:

@@ -228,12 +257,17 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    elif model_type == "opt":
        return OPTSharded(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    elif model_type == "t5":

@@ -241,6 +275,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

@@ -253,11 +288,19 @@ def get_model(
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    auto_map = config_dict.get("auto_map", None)

@@ -267,6 +310,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
    if "AutoModelForSeq2SeqLM" in auto_map.keys():

@@ -274,6 +318,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
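`get_model` converts the string coming from the CLI into a `torch.dtype` once, up front, defaulting to `float16`; every model constructor below then receives a real `torch.dtype`. A self-contained sketch of that mapping (the helper name is illustrative):

    from typing import Optional
    import torch

    def resolve_dtype(dtype: Optional[str]) -> torch.dtype:
        # Mirrors the branch added at the top of get_model.
        if dtype is None or dtype == "float16":
            return torch.float16
        if dtype == "bfloat16":
            return torch.bfloat16
        raise RuntimeError(f"Unknown dtype {dtype}")

    assert resolve_dtype(None) is torch.float16
    assert resolve_dtype("bfloat16") is torch.bfloat16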
server/text_generation_server/models/bloom.py

@@ -42,12 +42,13 @@ class BLOOMSharded(CausalLM):
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            device = torch.device("cpu")
            dtype = torch.float32
server/text_generation_server/models/causal_lm.py

@@ -454,11 +454,12 @@ class CausalLM(Model):
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.float16 if dtype is None else dtype
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")
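bloom.py, causal_lm.py, and the remaining model classes below all apply the same one-line change: a forced dtype wins, otherwise the old defaults stay (`float16` on GPU, `float32` on CPU, with quantization still rejected on CPU). A sketch of that selection logic in isolation (the function name is illustrative):

    from typing import Optional
    import torch

    def pick_device_and_dtype(dtype: Optional[torch.dtype] = None, quantize: Optional[str] = None):
        # Same shape as the constructors touched by this commit.
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.float16 if dtype is None else dtype
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")
            device = torch.device("cpu")
            dtype = torch.float32
        return device, dtype

    device, dtype = pick_device_and_dtype(torch.bfloat16)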
server/text_generation_server/models/flash_llama.py

@@ -25,12 +25,13 @@ class FlashLlama(FlashCausalLM):
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            raise NotImplementedError("FlashLlama is only available on GPU")
server/text_generation_server/models/flash_neox.py

@@ -24,12 +24,13 @@ class FlashNeoXSharded(FlashCausalLM):
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            raise NotImplementedError("FlashNeoX is only available on GPU")
server/text_generation_server/models/flash_rw.py

@@ -25,12 +25,13 @@ class FlashRWSharded(FlashCausalLM):
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            raise NotImplementedError("FlashRW is only available on GPU")
server/text_generation_server/models/flash_santacoder.py

@@ -24,12 +24,13 @@ class FlashSantacoderSharded(FlashCausalLM):
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
server/text_generation_server/models/galactica.py

@@ -158,12 +158,13 @@ class GalacticaSharded(CausalLM):
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            device = torch.device("cpu")
            dtype = torch.float32
server/text_generation_server/models/gpt_neox.py

@@ -24,12 +24,13 @@ class GPTNeoxSharded(CausalLM):
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            device = torch.device("cpu")
            dtype = torch.float32
server/text_generation_server/models/opt.py

@@ -22,12 +22,13 @@ class OPTSharded(CausalLM):
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            device = torch.device("cpu")
            dtype = torch.float32
server/text_generation_server/models/rw.py

@@ -12,11 +12,12 @@ class RW(CausalLM):
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.float16 if dtype is None else dtype
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")
server/text_generation_server/models/santacoder.py

@@ -19,11 +19,12 @@ class SantaCoder(CausalLM):
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.float16 if dtype is None else dtype
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")
server/text_generation_server/models/seq2seq_lm.py

@@ -504,11 +504,12 @@ class Seq2SeqLM(Model):
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.float16 if dtype is None else dtype
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")
server/text_generation_server/models/t5.py

@@ -25,12 +25,13 @@ class T5Sharded(Seq2SeqLM):
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            device = torch.device("cpu")
            dtype = torch.float32
server/text_generation_server/server.py

@@ -106,6 +106,7 @@ def serve(
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    dtype: Optional[str],
    trust_remote_code: bool,
    uds_path: Path,
):

@@ -114,6 +115,7 @@ def serve(
        revision: Optional[str],
        sharded: bool = False,
        quantize: Optional[str] = None,
        dtype: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
        unix_socket_template = "unix://{}-{}"

@@ -128,7 +130,9 @@ def serve(
        server_urls = [local_url]

        try:
            model = get_model(
                model_id, revision, sharded, quantize, dtype, trust_remote_code
            )
        except Exception:
            logger.exception("Error when initializing model")
            raise

@@ -159,4 +163,6 @@ def serve(
        logger.info("Signal received. Shutting down")
        await server.stop(0)

    asyncio.run(
        serve_inner(
            model_id, revision, sharded, quantize, dtype, trust_remote_code
        )
    )