OpenDAS / text-generation-inference, commit 85aa7e2e (unverified)

feat(server): support hf endpoint weight layout (#266)

Authored May 03, 2023 by OlivierDehaene; committed by GitHub on May 03, 2023.
Parent: 4096000e

Showing 11 changed files with 146 additions and 26 deletions (+146, -26).
Changed files:

  launcher/src/main.rs                                       +15  -9
  server/text_generation_server/cli.py                       +40  -15
  server/text_generation_server/models/bloom.py              +9   -0
  server/text_generation_server/models/flash_llama.py        +19  -0
  server/text_generation_server/models/flash_neox.py         +10  -0
  server/text_generation_server/models/flash_santacoder.py   +10  -0
  server/text_generation_server/models/galactica.py          +9   -0
  server/text_generation_server/models/gpt_neox.py           +9   -0
  server/text_generation_server/models/opt.py                +9   -0
  server/text_generation_server/models/t5.py                 +9   -0
  server/text_generation_server/utils/hub.py                 +7   -2
launcher/src/main.rs

@@ -512,7 +512,11 @@ enum LauncherError {
     WebserverCannotStart,
 }
 
-fn download_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
+fn download_convert_model(
+    args: &Args,
+    auto_convert: bool,
+    running: Arc<AtomicBool>,
+) -> Result<(), LauncherError> {
     let mut download_argv = vec![
         "text-generation-server".to_string(),
         "download-weights".to_string(),
@@ -524,6 +528,11 @@ fn download_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherE
         "--json-output".to_string(),
     ];
 
+    // Auto convert weights to safetensors
+    if auto_convert {
+        download_argv.push("--auto-convert".to_string());
+    }
+
     // Model optional revision
     if let Some(revision) = &args.revision {
         download_argv.push("--revision".to_string());
@@ -855,14 +864,11 @@ fn main() -> Result<(), LauncherError> {
     })
     .expect("Error setting Ctrl-C handler");
 
-    // Check if model_id is a local model
-    let local_path = Path::new(&args.model_id);
-    let is_local_model = local_path.exists() && local_path.is_dir();
-
-    // Download weights for sharded models
-    if !is_local_model && args.weights_cache_override.is_none() && num_shard > 1 {
-        download_model(&args, running.clone())?;
-    }
+    // auto_convert is only needed for sharded models as we do not require safetensors in
+    // single shard mode
+    let auto_convert = num_shard > 1;
+
+    // Download and convert model weights
+    download_convert_model(&args, auto_convert, running.clone())?;
 
     // Shared shutdown bool
     let shutdown = Arc::new(Mutex::new(false));
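The launcher-side change threads one new decision through to the Python CLI: convert to safetensors only when serving more than one shard, since single-shard mode can load .bin weights directly. A minimal Python mirror of the argv the launcher assembles; only "download-weights", "--json-output", "--auto-convert", and "--revision" are taken from the diff above, and the model-id position is an assumption for illustration:

# Sketch: Python mirror of download_convert_model's argument construction.
# The model-id position is assumed; the flags come from the diff above.
from typing import List, Optional

def build_download_argv(
    model_id: str, num_shard: int, revision: Optional[str] = None
) -> List[str]:
    argv = ["text-generation-server", "download-weights", model_id, "--json-output"]
    # auto_convert is only needed for sharded models: single-shard
    # mode does not require safetensors.
    if num_shard > 1:
        argv.append("--auto-convert")
    if revision is not None:
        argv.extend(["--revision", revision])
    return argv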
server/text_generation_server/cli.py

@@ -63,6 +63,7 @@ def download_weights(
     model_id: str,
     revision: Optional[str] = None,
     extension: str = ".safetensors",
+    auto_convert: bool = True,
     logger_level: str = "INFO",
     json_output: bool = False,
 ):
@@ -84,31 +85,55 @@ def download_weights(
     # Test if files were already download
     try:
         utils.weight_files(model_id, revision, extension)
         logger.info(
-            "Files are already present on the host. " "Skipping download."
+            "Files are already present in the local cache. " "Skipping download."
         )
         return
     # Local files not found
-    except utils.LocalEntryNotFoundError:
+    except (utils.LocalEntryNotFoundError, FileNotFoundError):
         pass
 
-    # Download weights directly
-    try:
-        filenames = utils.weight_hub_files(model_id, revision, extension)
-        utils.download_weights(filenames, model_id, revision)
-    except utils.EntryNotFoundError as e:
-        if not extension == ".safetensors":
-            raise e
-
-        logger.warning(
-            f"No safetensors weights found for model {model_id} at revision {revision}. "
-            f"Converting PyTorch weights instead."
-        )
-        # Try to see if there are pytorch weights
-        pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
-        # Download pytorch weights
-        local_pt_files = utils.download_weights(pt_filenames, model_id, revision)
+    is_local_model = (Path(model_id).exists() and Path(model_id).is_dir()) or os.getenv(
+        "WEIGHTS_CACHE_OVERRIDE", None
+    ) is not None
+
+    if not is_local_model:
+        # Try to download weights from the hub
+        try:
+            filenames = utils.weight_hub_files(model_id, revision, extension)
+            utils.download_weights(filenames, model_id, revision)
+            # Successfully downloaded weights
+            return
+
+        # No weights found on the hub with this extension
+        except utils.EntryNotFoundError as e:
+            # Check if we want to automatically convert to safetensors or if we can use .bin weights instead
+            if not extension == ".safetensors" or not auto_convert:
+                raise e
+
+    # Try to see if there are local pytorch weights
+    try:
+        # Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
+        local_pt_files = utils.weight_files(model_id, revision, ".bin")
+
+    # No local pytorch weights
+    except utils.LocalEntryNotFoundError:
+        if extension == ".safetensors":
+            logger.warning(
+                f"No safetensors weights found for model {model_id} at revision {revision}. "
+                f"Downloading PyTorch weights."
+            )
+
+        # Try to see if there are pytorch weights on the hub
+        pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
+        # Download pytorch weights
+        local_pt_files = utils.download_weights(pt_filenames, model_id, revision)
+
+    if auto_convert:
+        logger.warning(
+            f"No safetensors weights found for model {model_id} at revision {revision}. "
+            f"Converting PyTorch weights to safetensors."
+        )
+
+        # Safetensors final filenames
         local_st_files = [
             p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors"
             for p in local_pt_files
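The rewritten download_weights now resolves weights in a fixed order: already-cached safetensors, then hub safetensors (skipped for local models and WEIGHTS_CACHE_OVERRIDE setups), then local .bin files, then hub .bin files, converting to safetensors at the end when auto_convert is set. A hedged usage sketch of the updated entry point; "bigscience/bloom-560m" is just an example model id:

# Usage sketch: calling the updated CLI entry point directly.
from text_generation_server.cli import download_weights

download_weights(
    model_id="bigscience/bloom-560m",  # example id, any hub model works
    auto_convert=True,  # new in this commit: fall back to .bin and convert
)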
server/text_generation_server/models/bloom.py

@@ -223,6 +223,15 @@ class BLOOMSharded(BLOOM):
                 if name == "word_embeddings.weight":
                     model.lm_head._parameters["weight"] = tensor
 
+        uninitialized_parameters = []
+        for n, p in model.named_parameters():
+            if p.data.device == torch.device("meta"):
+                uninitialized_parameters.append(n)
+        if uninitialized_parameters:
+            raise RuntimeError(
+                f"found uninitialized parameters in model: {uninitialized_parameters}"
+            )
+
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ):
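The block added here (and repeated in each model below) catches weights that were never materialized during loading: any parameter still on PyTorch's meta device is collected by name and reported. A standalone sketch of the same check, with an artificial meta parameter so the error path fires:

import torch
from torch import nn

model = nn.Linear(4, 4)
# Simulate a weight a sharded load never wrote: leave it on the meta device.
model.weight = nn.Parameter(torch.empty(4, 4, device="meta"))

# Same check as the diff above: collect names of unmaterialized parameters.
uninitialized_parameters = []
for n, p in model.named_parameters():
    if p.data.device == torch.device("meta"):
        uninitialized_parameters.append(n)
if uninitialized_parameters:
    raise RuntimeError(
        f"found uninitialized parameters in model: {uninitialized_parameters}"
    )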
server/text_generation_server/models/flash_llama.py

@@ -139,6 +139,15 @@ class FlashLlama(FlashCausalLM):
             del value
 
+        uninitialized_parameters = []
+        for n, p in model.named_parameters():
+            if p.data.device == torch.device("meta"):
+                uninitialized_parameters.append(n)
+        if uninitialized_parameters:
+            raise RuntimeError(
+                f"found uninitialized parameters in model: {uninitialized_parameters}"
+            )
+
         torch.cuda.empty_cache()
         model.post_load_weights(quantize)
@@ -300,5 +309,15 @@ class FlashLlamaSharded(FlashLlama):
             else:
                 module._buffers[param_name] = tensor
 
+        uninitialized_parameters = []
+        for n, p in model.named_parameters():
+            if p.data.device == torch.device("meta"):
+                uninitialized_parameters.append(n)
+        if uninitialized_parameters:
+            raise RuntimeError(
+                f"found uninitialized parameters in model: {uninitialized_parameters}"
+            )
+
         torch.cuda.empty_cache()
         model.post_load_weights(quantize)
server/text_generation_server/models/flash_neox.py

@@ -149,4 +149,14 @@ class FlashNeoXSharded(FlashNeoX):
                 module._parameters[param_name] = tensor
             else:
                 module._buffers[param_name] = tensor
 
+        uninitialized_parameters = []
+        for n, p in model.named_parameters():
+            if p.data.device == torch.device("meta"):
+                uninitialized_parameters.append(n)
+        if uninitialized_parameters:
+            raise RuntimeError(
+                f"found uninitialized parameters in model: {uninitialized_parameters}"
+            )
+
         model.post_load_weights(quantize)
server/text_generation_server/models/flash_santacoder.py

@@ -372,5 +372,15 @@ class FlashSantacoderSharded(FlashSantacoder):
                 module._parameters[param_name] = tensor
             else:
                 module._buffers[param_name] = tensor
 
+        uninitialized_parameters = []
+        for n, p in model.named_parameters():
+            if p.data.device == torch.device("meta"):
+                uninitialized_parameters.append(n)
+        if uninitialized_parameters:
+            raise RuntimeError(
+                f"found uninitialized parameters in model: {uninitialized_parameters}"
+            )
+
         torch.cuda.empty_cache()
         model.post_load_weights(quantize)
server/text_generation_server/models/galactica.py

@@ -355,6 +355,15 @@ class GalacticaSharded(Galactica):
                 if name == "model.decoder.embed_tokens.weight":
                     model.lm_head._parameters["weight"] = tensor
 
+        uninitialized_parameters = []
+        for n, p in model.named_parameters():
+            if p.data.device == torch.device("meta"):
+                uninitialized_parameters.append(n)
+        if uninitialized_parameters:
+            raise RuntimeError(
+                f"found uninitialized parameters in model: {uninitialized_parameters}"
+            )
+
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ):
server/text_generation_server/models/gpt_neox.py

@@ -205,6 +205,15 @@ class GPTNeoxSharded(CausalLM):
             else:
                 module._buffers[param_name] = tensor
 
+        uninitialized_parameters = []
+        for n, p in model.named_parameters():
+            if p.data.device == torch.device("meta"):
+                uninitialized_parameters.append(n)
+        if uninitialized_parameters:
+            raise RuntimeError(
+                f"found uninitialized parameters in model: {uninitialized_parameters}"
+            )
+
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ):
server/text_generation_server/models/opt.py

@@ -210,6 +210,15 @@ class OPTSharded(OPT):
                 if name == "model.decoder.embed_tokens.weight":
                     model.lm_head._parameters["weight"] = tensor
 
+        uninitialized_parameters = []
+        for n, p in model.named_parameters():
+            if p.data.device == torch.device("meta"):
+                uninitialized_parameters.append(n)
+        if uninitialized_parameters:
+            raise RuntimeError(
+                f"found uninitialized parameters in model: {uninitialized_parameters}"
+            )
+
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ):
server/text_generation_server/models/t5.py

@@ -211,6 +211,15 @@ class T5Sharded(Seq2SeqLM):
             else:
                 module._buffers[param_name] = tensor
 
+        uninitialized_parameters = []
+        for n, p in model.named_parameters():
+            if p.data.device == torch.device("meta"):
+                uninitialized_parameters.append(n)
+        if uninitialized_parameters:
+            raise RuntimeError(
+                f"found uninitialized parameters in model: {uninitialized_parameters}"
+            )
+
     def forward(
         self,
         input_ids,
server/text_generation_server/utils/hub.py

@@ -77,7 +77,12 @@ def weight_files(
     """Get the local files"""
     # Local model
     if Path(model_id).exists() and Path(model_id).is_dir():
-        return list(Path(model_id).glob(f"*{extension}"))
+        local_files = list(Path(model_id).glob(f"*{extension}"))
+        if not local_files:
+            raise FileNotFoundError(
+                f"No local weights found in {model_id} with extension {extension}"
+            )
+        return local_files
 
     try:
         filenames = weight_hub_files(model_id, revision, extension)
@@ -98,7 +103,7 @@ def weight_files(
         for filename in filenames:
             p = Path(WEIGHTS_CACHE_OVERRIDE) / filename
             if not p.exists():
-                raise LocalEntryNotFoundError(
+                raise FileNotFoundError(
                     f"File {p} not found in {WEIGHTS_CACHE_OVERRIDE}."
                 )
             files.append(p)
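With these two changes, an empty local directory or a missing file under WEIGHTS_CACHE_OVERRIDE surfaces as a plain FileNotFoundError, which the new except (utils.LocalEntryNotFoundError, FileNotFoundError) clause in cli.py treats as "nothing cached yet" and falls through to downloading. A self-contained sketch of that interaction; the helper name local_weight_files is hypothetical, but its behaviour mirrors the diff above:

import tempfile
from pathlib import Path

def local_weight_files(model_dir: str, extension: str = ".safetensors"):
    # Mirrors the new weight_files behaviour for local directories:
    # an empty glob now raises instead of returning an empty list.
    local_files = list(Path(model_dir).glob(f"*{extension}"))
    if not local_files:
        raise FileNotFoundError(
            f"No local weights found in {model_dir} with extension {extension}"
        )
    return local_files

with tempfile.TemporaryDirectory() as empty_dir:
    try:
        local_weight_files(empty_dir)
    except FileNotFoundError:
        print("no local weights yet, falling back to download")  # cli.py's new path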