OpenDAS / text-generation-inference · Commits

Commit cd5961b5 (unverified)
Authored Mar 06, 2023 by OlivierDehaene; committed by GitHub on Mar 06, 2023
feat: allow local models (#101)
closes #99
Parent: 9b205d33
Showing 7 changed files with 24 additions and 13 deletions
launcher/src/main.rs                        +5 −1
router/src/main.rs                          +13 −4
router/src/queue.rs                         +1 −1
server/text_generation/models/__init__.py   +1 −1
server/text_generation/models/bloom.py      +0 −3
server/text_generation/models/galactica.py  +0 −3
server/text_generation/utils/hub.py         +4 −0
launcher/src/main.rs

```diff
@@ -110,8 +110,12 @@ fn main() -> ExitCode {
     })
     .expect("Error setting Ctrl-C handler");
 
+    // Check if model_id is a local model
+    let local_path = Path::new(&model_id);
+    let is_local_model = local_path.exists() && local_path.is_dir();
+
     // Download weights for sharded models
-    if weights_cache_override.is_none() && num_shard > 1 {
+    if !is_local_model && weights_cache_override.is_none() && num_shard > 1 {
         let mut download_argv = vec![
             "text-generation-server".to_string(),
             "download-weights".to_string(),
```
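The guard added to the launcher is a plain filesystem check: when `model_id` names an existing directory, the shard weight-download step is skipped entirely. A minimal standalone sketch of that check; the `maybe_download` function and the paths in `main` are illustrative, not part of the commit:

```rust
use std::path::Path;

// Hypothetical stand-in for the launcher's weight-download step.
fn maybe_download(model_id: &str, num_shard: usize) {
    // Treat `model_id` as a local model when it names an existing directory,
    // mirroring the check added in launcher/src/main.rs.
    let local_path = Path::new(model_id);
    let is_local_model = local_path.exists() && local_path.is_dir();

    if !is_local_model && num_shard > 1 {
        println!("downloading weights for {model_id}");
    } else {
        println!("skipping download for {model_id}");
    }
}

fn main() {
    // Skipped only if this directory actually exists on disk.
    maybe_download("/data/models/bloom-560m", 2);
    // A Hub identifier is not a local directory, so this one would download.
    maybe_download("bigscience/bloom-560m", 2);
}
```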
router/src/main.rs

```diff
@@ -8,6 +8,7 @@ use opentelemetry::sdk::Resource;
 use opentelemetry::{global, KeyValue};
 use opentelemetry_otlp::WithExportConfig;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
+use std::path::Path;
 use text_generation_client::ShardedClient;
 use text_generation_router::server;
 use tokenizers::Tokenizer;
@@ -83,11 +84,19 @@ fn main() -> Result<(), std::io::Error> {
         )
     });
 
-    // Download and instantiate tokenizer
+    // Tokenizer instance
     // This will only be used to validate payloads
-    // We need to download it outside of the Tokio runtime
-    let tokenizer = Tokenizer::from_pretrained(tokenizer_name.clone(), None).unwrap();
+    let local_path = Path::new(&tokenizer_name);
+    let tokenizer =
+        if local_path.exists() && local_path.is_dir() && local_path.join("tokenizer.json").exists()
+        {
+            // Load local tokenizer
+            Tokenizer::from_file(local_path.join("tokenizer.json")).unwrap()
+        } else {
+            // Download and instantiate tokenizer
+            // We need to download it outside of the Tokio runtime
+            Tokenizer::from_pretrained(tokenizer_name.clone(), None).unwrap()
+        };
 
     // Launch Tokio runtime
     tokio::runtime::Builder::new_multi_thread()
```
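The router applies the same idea to the tokenizer: prefer a `tokenizer.json` found inside a local directory and only fall back to a Hub download. A rough sketch of that fallback shape using the `tokenizers` crate (assumes the crate's `http` feature for `from_pretrained`; error handling reduced to `unwrap`, as in the diff):

```rust
use std::path::Path;
use tokenizers::Tokenizer;

fn load_tokenizer(tokenizer_name: &str) -> Tokenizer {
    let local_path = Path::new(tokenizer_name);
    if local_path.is_dir() && local_path.join("tokenizer.json").exists() {
        // Local directory with a serialized tokenizer: load it from disk.
        Tokenizer::from_file(local_path.join("tokenizer.json")).unwrap()
    } else {
        // Otherwise treat the name as a Hub identifier and download it.
        Tokenizer::from_pretrained(tokenizer_name, None).unwrap()
    }
}
```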
router/src/queue.rs

```diff
@@ -234,7 +234,7 @@ mod tests {
             do_sample: false,
             seed: 0,
             repetition_penalty: 0.0,
-            watermark: false
+            watermark: false,
         },
         stopping_parameters: StoppingCriteriaParameters {
             max_new_tokens: 0,
```
server/text_generation/models/__init__.py

```diff
@@ -41,7 +41,7 @@ torch.set_grad_enabled(False)
 def get_model(
     model_id: str, revision: Optional[str], sharded: bool, quantize: bool
 ) -> Model:
-    if model_id.startswith("facebook/galactica"):
+    if "facebook/galactica" in model_id:
         if sharded:
             return GalacticaSharded(model_id, revision, quantize=quantize)
         else:
```
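The dispatch now uses a substring match instead of a prefix match, so a local path that embeds the repo name (say, a hypothetical `/data/facebook/galactica-125m`) still routes to the Galactica classes; the now-redundant `startswith` guards inside the model constructors are dropped in `bloom.py` and `galactica.py` below. The distinction in a tiny Rust sketch:

```rust
fn is_galactica(model_id: &str) -> bool {
    // Substring match: also true for local paths embedding the repo name.
    model_id.contains("facebook/galactica")
}

fn main() {
    assert!(is_galactica("facebook/galactica-125m"));
    // A prefix check (`starts_with`) would reject this local path:
    assert!(is_galactica("/data/facebook/galactica-125m"));
    assert!(!is_galactica("bigscience/bloom"));
}
```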
server/text_generation/models/bloom.py

```diff
@@ -58,9 +58,6 @@ class BLOOMSharded(BLOOM):
     def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
-        if not model_id.startswith("bigscience/bloom"):
-            raise ValueError(f"Model {model_id} is not supported")
-
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
         if torch.cuda.is_available():
```
server/text_generation/models/galactica.py

```diff
@@ -164,9 +164,6 @@ class GalacticaSharded(Galactica):
     def __init__(
         self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
-        if not model_id.startswith("facebook/galactica"):
-            raise ValueError(f"Model {model_id} is not supported")
-
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
         if torch.cuda.is_available():
```
server/text_generation/utils/hub.py

```diff
@@ -80,6 +80,10 @@ def weight_files(
     model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
 ) -> List[Path]:
     """Get the local files"""
+    # Local model
+    if Path(model_id).exists() and Path(model_id).is_dir():
+        return list(Path(model_id).glob(f"*{extension}"))
+
     try:
         filenames = weight_hub_files(model_id, revision, extension)
     except EntryNotFoundError as e:
```
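`weight_files` now short-circuits for local directories, returning every file with the requested extension rather than resolving cached Hub downloads. A rough Rust rendering of that glob (the function name and `Option` return are illustrative, not part of the commit):

```rust
use std::fs;
use std::path::{Path, PathBuf};

/// List files under `model_id` ending in `extension`, e.g. ".safetensors".
fn local_weight_files(model_id: &str, extension: &str) -> Option<Vec<PathBuf>> {
    let dir = Path::new(model_id);
    if !dir.is_dir() {
        // Not a local model: the caller would fall back to the Hub.
        return None;
    }
    let files = fs::read_dir(dir)
        .ok()?
        .filter_map(|entry| entry.ok())
        .map(|entry| entry.path())
        .filter(|path| path.to_string_lossy().ends_with(extension))
        .collect();
    Some(files)
}
```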