OpenDAS / text-generation-inference, commit 0fbc6919

Unverified commit 0fbc6919, authored Feb 14, 2023 by OlivierDehaene, committed by GitHub on Feb 14, 2023

feat: add safetensors conversion (#63)

Parent: 9af45414
Changes: 17 changed files, with 749 additions and 97 deletions (+749, -97).
README.md (+4, -4)
launcher/src/main.rs (+140, -10)
server/tests/utils/test_convert.py (+17, -0)
server/tests/utils/test_hub.py (+40, -0)
server/tests/utils/test_tokens.py (+1, -37)
server/text_generation/cli.py (+48, -1)
server/text_generation/models/__init__.py (+20, -16)
server/text_generation/models/bloom.py (+0, -7)
server/text_generation/models/galactica.py (+0, -7)
server/text_generation/models/gpt_neox.py (+0, -7)
server/text_generation/models/santacoder.py (+1, -1)
server/text_generation/models/t5.py (+0, -7)
server/text_generation/utils/__init__.py (+36, -0)
server/text_generation/utils/convert.py (+96, -0)
server/text_generation/utils/dist.py (+35, -0)
server/text_generation/utils/hub.py (+169, -0)
server/text_generation/utils/tokens.py (+142, -0)
README.md

@@ -49,17 +49,17 @@ to power LLMs api-inference widgets.
 - Log probabilities
 - Distributed tracing with Open Telemetry
 
-## Officially supported models
+## Officially supported architectures
 
 - [BLOOM](https://huggingface.co/bigscience/bloom)
 - [BLOOMZ](https://huggingface.co/bigscience/bloomz)
 - [MT0-XXL](https://huggingface.co/bigscience/mt0-xxl)
-- ~~[Galactica](https://huggingface.co/facebook/galactica-120b)~~ (deactivated)
+- [Galactica](https://huggingface.co/facebook/galactica-120b)
 - [SantaCoder](https://huggingface.co/bigcode/santacoder)
 - [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b)
 - [FLAN-T5-XXL](https://huggingface.co/google/flan-t5-xxl)
 
-Other models are supported on a best effort basis using:
+Other architectures are supported on a best effort basis using:
 
 `AutoModelForCausalLM.from_pretrained(<model>, device_map="auto")`

@@ -191,7 +191,7 @@ Be aware that the official Docker image has them enabled by default.
 ### Download
 
-First you need to download the weights:
+It is advised to download the weights ahead of time with the following command:
 
 ```shell
 make download-bloom
launcher/src/main.rs

@@ -12,7 +12,7 @@ use std::thread;
 use std::thread::sleep;
 use std::time::{Duration, Instant};
 use std::{fs, io};
-use subprocess::{Popen, PopenConfig, PopenError, Redirection};
+use subprocess::{ExitStatus, Popen, PopenConfig, PopenError, Redirection};
 
 /// App Configuration
 #[derive(Parser, Debug)]

@@ -43,6 +43,10 @@ struct Args {
     #[clap(default_value = "29500", long, env)]
     master_port: usize,
     #[clap(long, env)]
+    huggingface_hub_cache: Option<String>,
+    #[clap(long, env)]
+    weights_cache_override: Option<String>,
+    #[clap(long, env)]
     json_output: bool,
     #[clap(long, env)]
     otlp_endpoint: Option<String>,

@@ -63,6 +67,8 @@ fn main() -> ExitCode {
         shard_uds_path,
         master_addr,
         master_port,
+        huggingface_hub_cache,
+        weights_cache_override,
         json_output,
         otlp_endpoint,
     } = Args::parse();

@@ -84,6 +90,124 @@ fn main() -> ExitCode {
     })
     .expect("Error setting Ctrl-C handler");
 
+    // Download weights
+    if weights_cache_override.is_none() {
+        let mut download_argv = vec![
+            "text-generation-server".to_string(),
+            "download-weights".to_string(),
+            model_id.clone(),
+            "--logger-level".to_string(),
+            "INFO".to_string(),
+            "--json-output".to_string(),
+        ];
+        if num_shard == 1 {
+            download_argv.push("--extension".to_string());
+            download_argv.push(".bin".to_string());
+        } else {
+            download_argv.push("--extension".to_string());
+            download_argv.push(".safetensors".to_string());
+        }
+
+        // Model optional revision
+        if let Some(ref revision) = revision {
+            download_argv.push("--revision".to_string());
+            download_argv.push(revision.to_string())
+        }
+
+        let mut env = Vec::new();
+
+        // If the HUGGINGFACE_HUB_CACHE env var is set, pass it to the shard
+        // Useful when running inside a docker container
+        if let Some(ref huggingface_hub_cache) = huggingface_hub_cache {
+            env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
+        };
+
+        // Start process
+        tracing::info!("Starting download");
+        let mut download_process = match Popen::create(
+            &download_argv,
+            PopenConfig {
+                stdout: Redirection::Pipe,
+                stderr: Redirection::Pipe,
+                // Needed for the shutdown procedure
+                setpgid: true,
+                env: Some(env),
+                ..Default::default()
+            },
+        ) {
+            Ok(p) => p,
+            Err(err) => {
+                if let PopenError::IoError(ref err) = err {
+                    if err.kind() == io::ErrorKind::NotFound {
+                        tracing::error!("text-generation-server not found in PATH");
+                        tracing::error!("Please install it with `make install-server`")
+                    }
+                }
+                return ExitCode::FAILURE;
+            }
+        };
+
+        // Redirect STDOUT to the console
+        let download_stdout = download_process.stdout.take().unwrap();
+        thread::spawn(move || {
+            // Enter download tracing span
+            let stdout = BufReader::new(download_stdout);
+            let _span = tracing::span!(tracing::Level::INFO, "download").entered();
+
+            for line in stdout.lines() {
+                // Parse loguru logs
+                if let Ok(value) = serde_json::from_str::<Value>(&line.unwrap()) {
+                    if let Some(text) = value.get("text") {
+                        // Format escaped newlines
+                        tracing::info!("{}", text.to_string().replace("\\n", ""));
+                    }
+                }
+            }
+        });
+
+        loop {
+            if let Some(status) = download_process.poll() {
+                match status {
+                    ExitStatus::Exited(exit_code) => {
+                        if exit_code == 0 {
+                            tracing::info!("Successfully downloaded weights.");
+                            break;
+                        } else {
+                            let mut err = String::new();
+                            download_process
+                                .stderr
+                                .take()
+                                .unwrap()
+                                .read_to_string(&mut err)
+                                .unwrap();
+                            tracing::error!("Download encountered an error: {err}");
+                            return ExitCode::FAILURE;
+                        }
+                    }
+                    _ => {
+                        tracing::error!("Download process exited with an unkown status.");
+                        return ExitCode::FAILURE;
+                    }
+                }
+            }
+            if !running.load(Ordering::SeqCst) {
+                download_process.terminate().unwrap();
+                tracing::info!("Waiting for download process to gracefully shutdown");
+                download_process
+                    .wait_timeout(Duration::from_secs(90))
+                    .unwrap();
+                tracing::info!("Download process terminated");
+                return ExitCode::SUCCESS;
+            }
+            sleep(Duration::from_millis(100));
+        }
+    } else {
+        tracing::info!(
+            "weights_cache_override is set to {:?}.",
+            weights_cache_override
+        );
+        tracing::info!("Skipping download.")
+    }
+
     // Shared shutdown bool
     let shutdown = Arc::new(Mutex::new(false));
     // Shared shutdown channel

@@ -99,6 +223,8 @@ fn main() -> ExitCode {
         let revision = revision.clone();
         let uds_path = shard_uds_path.clone();
         let master_addr = master_addr.clone();
+        let huggingface_hub_cache = huggingface_hub_cache.clone();
+        let weights_cache_override = weights_cache_override.clone();
         let status_sender = status_sender.clone();
         let shutdown = shutdown.clone();
         let shutdown_sender = shutdown_sender.clone();

@@ -113,6 +239,8 @@ fn main() -> ExitCode {
                 num_shard,
                 master_addr,
                 master_port,
+                huggingface_hub_cache,
+                weights_cache_override,
                 otlp_endpoint,
                 status_sender,
                 shutdown,

@@ -232,7 +360,7 @@ fn main() -> ExitCode {
     while running.load(Ordering::SeqCst) {
         if let Ok(ShardStatus::Failed((rank, err))) = status_receiver.try_recv() {
-            tracing::error!("Shard {} failed:\n{}", rank, err);
+            tracing::error!("Shard {rank} failed:\n{err}");
             exit_code = ExitCode::FAILURE;
             break;
         };

@@ -275,6 +403,8 @@ fn shard_manager(
     world_size: usize,
     master_addr: String,
     master_port: usize,
+    huggingface_hub_cache: Option<String>,
+    weights_cache_override: Option<String>,
     otlp_endpoint: Option<String>,
     status_sender: mpsc::Sender<ShardStatus>,
     shutdown: Arc<Mutex<bool>>,

@@ -328,15 +458,15 @@ fn shard_manager(
         ("NCCL_ASYNC_ERROR_HANDLING".into(), "1".into()),
     ];
 
-    // If the HUGGINGFACE_HUB_CACHE env var is set, pass it to the shard
+    // If huggingface_hub_cache is some, pass it to the shard
     // Useful when running inside a docker container
-    if let Ok(huggingface_hub_cache) = env::var("HUGGINGFACE_HUB_CACHE") {
+    if let Some(huggingface_hub_cache) = huggingface_hub_cache {
         env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
     };
 
-    // If the WEIGHTS_CACHE_OVERRIDE env var is set, pass it to the shard
+    // If weights_cache_override is some, pass it to the shard
     // Useful when running inside a HuggingFace Inference Endpoint
-    if let Ok(weights_cache_override) = env::var("WEIGHTS_CACHE_OVERRIDE") {
+    if let Some(weights_cache_override) = weights_cache_override {
         env.push((
             "WEIGHTS_CACHE_OVERRIDE".into(),
             weights_cache_override.into(),

@@ -355,7 +485,7 @@ fn shard_manager(
     };
 
     // Start process
-    tracing::info!("Starting shard {}", rank);
+    tracing::info!("Starting shard {rank}");
     let mut p = match Popen::create(
         &shard_argv,
         PopenConfig {

@@ -419,17 +549,17 @@ fn shard_manager(
         if *shutdown.lock().unwrap() {
             p.terminate().unwrap();
             let _ = p.wait_timeout(Duration::from_secs(90));
-            tracing::info!("Shard {} terminated", rank);
+            tracing::info!("Shard {rank} terminated");
             return;
         }
 
         // Shard is ready
         if uds.exists() && !ready {
-            tracing::info!("Shard {} ready in {:?}", rank, start_time.elapsed());
+            tracing::info!("Shard {rank} ready in {:?}", start_time.elapsed());
             status_sender.send(ShardStatus::Ready).unwrap();
             ready = true;
         } else if !ready && wait_time.elapsed() > Duration::from_secs(10) {
-            tracing::info!("Waiting for shard {} to be ready...", rank);
+            tracing::info!("Waiting for shard {rank} to be ready...");
             wait_time = Instant::now();
         }
         sleep(Duration::from_millis(100));
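For context, the new block above makes the launcher shell out to the Python server's CLI before any shard starts. A minimal sketch of the equivalent invocation, assuming `text-generation-server` is installed on PATH and using `bigscience/bloom-560m` purely as an example model id:

```python
# Illustrative sketch only: mirrors the argv assembled by the launcher above.
import subprocess

download_argv = [
    "text-generation-server",
    "download-weights",
    "bigscience/bloom-560m",  # example model id, not part of this commit
    "--logger-level", "INFO",
    "--json-output",
    "--extension", ".safetensors",  # the launcher picks ".bin" when num_shard == 1
]
subprocess.run(download_argv, check=True)
```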
server/tests/utils/test_convert.py (new file, mode 100644)

from text_generation.utils.hub import download_weights, weight_hub_files, weight_files
from text_generation.utils.convert import convert_files


def test_convert_files():
    model_id = "bigscience/bloom-560m"
    pt_filenames = weight_hub_files(model_id, extension=".bin")
    local_pt_files = download_weights(pt_filenames, model_id)
    local_st_files = [
        p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" for p in local_pt_files
    ]
    convert_files(local_pt_files, local_st_files)

    found_st_files = weight_files(model_id)

    assert all([p in found_st_files for p in local_st_files])
server/tests/utils/test_hub.py (new file, mode 100644)

import pytest

from text_generation.utils.hub import (
    weight_hub_files,
    download_weights,
    weight_files,
    EntryNotFoundError,
    LocalEntryNotFoundError,
    RevisionNotFoundError,
)


def test_weight_hub_files():
    filenames = weight_hub_files("bigscience/bloom-560m")
    assert filenames == ["model.safetensors"]


def test_weight_hub_files_llm():
    filenames = weight_hub_files("bigscience/bloom")
    assert filenames == [f"model_{i:05d}-of-00072.safetensors" for i in range(1, 73)]


def test_weight_hub_files_empty():
    with pytest.raises(EntryNotFoundError):
        weight_hub_files("bigscience/bloom", extension=".errors")


def test_download_weights():
    model_id = "bigscience/bloom-560m"
    filenames = weight_hub_files(model_id)
    files = download_weights(filenames, model_id)
    local_files = weight_files("bigscience/bloom-560m")
    assert files == local_files


def test_weight_files_error():
    with pytest.raises(RevisionNotFoundError):
        weight_files("bigscience/bloom-560m", revision="error")
    with pytest.raises(LocalEntryNotFoundError):
        weight_files("bert-base-uncased")
server/tests/test_utils.py → server/tests/utils/test_tokens.py

-import pytest
-
-from huggingface_hub.utils import RevisionNotFoundError
-
-from text_generation.utils import (
-    weight_hub_files,
-    download_weights,
-    weight_files,
+from text_generation.utils.tokens import (
     StopSequenceCriteria,
     StoppingCriteria,
-    LocalEntryNotFoundError,
     FinishReason,
 )

@@ -41,31 +33,3 @@ def test_stopping_criteria_max():
     assert criteria(1, "") == (False, None)
     assert criteria(1, "") == (False, None)
     assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)
-
-
-def test_weight_hub_files():
-    filenames = weight_hub_files("bigscience/bloom-560m")
-    assert filenames == ["model.safetensors"]
-
-
-def test_weight_hub_files_llm():
-    filenames = weight_hub_files("bigscience/bloom")
-    assert filenames == [f"model_{i:05d}-of-00072.safetensors" for i in range(1, 73)]
-
-
-def test_weight_hub_files_empty():
-    filenames = weight_hub_files("bigscience/bloom", extension=".errors")
-    assert filenames == []
-
-
-def test_download_weights():
-    files = download_weights("bigscience/bloom-560m")
-    local_files = weight_files("bigscience/bloom-560m")
-    assert files == local_files
-
-
-def test_weight_files_error():
-    with pytest.raises(RevisionNotFoundError):
-        weight_files("bigscience/bloom-560m", revision="error")
-    with pytest.raises(LocalEntryNotFoundError):
-        weight_files("bert-base-uncased")
server/text_generation/cli.py

@@ -60,8 +60,55 @@ def download_weights(
     model_id: str,
     revision: Optional[str] = None,
     extension: str = ".safetensors",
+    logger_level: str = "INFO",
+    json_output: bool = False,
 ):
-    utils.download_weights(model_id, revision, extension)
+    # Remove default handler
+    logger.remove()
+    logger.add(
+        sys.stdout,
+        format="{message}",
+        filter="text_generation",
+        level=logger_level,
+        serialize=json_output,
+        backtrace=True,
+        diagnose=False,
+    )
+
+    # Test if files were already download
+    try:
+        utils.weight_files(model_id, revision, extension)
+        logger.info(
+            "Files are already present in the local cache. " "Skipping download."
+        )
+        return
+    # Local files not found
+    except utils.LocalEntryNotFoundError:
+        pass
+
+    # Download weights directly
+    try:
+        filenames = utils.weight_hub_files(model_id, revision, extension)
+        utils.download_weights(filenames, model_id, revision)
+    except utils.EntryNotFoundError as e:
+        if not extension == ".safetensors":
+            raise e
+
+        logger.warning(
+            f"No safetensors weights found for model {model_id} at revision {revision}. "
+            f"Converting PyTorch weights instead."
+        )
+
+        # Try to see if there are pytorch weights
+        pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
+        # Download pytorch weights
+        local_pt_files = utils.download_weights(pt_filenames, model_id, revision)
+
+        local_st_files = [
+            p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors"
+            for p in local_pt_files
+        ]
+        # Convert pytorch weights to safetensors
+        utils.convert_files(local_pt_files, local_st_files)
 
 
 if __name__ == "__main__":
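The same fallback path can be driven from Python with the utilities this commit adds. A rough sketch, where the model id is only an example:

```python
from text_generation import utils

model_id = "bigscience/bloom-560m"  # example model id
revision = None

try:
    # If safetensors weights exist on the hub, just download them.
    filenames = utils.weight_hub_files(model_id, revision, ".safetensors")
    utils.download_weights(filenames, model_id, revision)
except utils.EntryNotFoundError:
    # Otherwise download the PyTorch checkpoints and convert them locally.
    pt_filenames = utils.weight_hub_files(model_id, revision, ".bin")
    local_pt_files = utils.download_weights(pt_filenames, model_id, revision)
    local_st_files = [
        p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" for p in local_pt_files
    ]
    utils.convert_files(local_pt_files, local_st_files)
```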
server/text_generation/models/__init__.py

@@ -41,6 +41,15 @@ torch.set_grad_enabled(False)
 def get_model(
     model_id: str, revision: Optional[str], sharded: bool, quantize: bool
 ) -> Model:
+    if model_id.startswith("facebook/galactica"):
+        if sharded:
+            return GalacticaSharded(model_id, revision, quantize=quantize)
+        else:
+            return Galactica(model_id, revision, quantize=quantize)
+
+    if "santacoder" in model_id:
+        return SantaCoder(model_id, revision, quantize)
+
     config = AutoConfig.from_pretrained(model_id, revision=revision)
 
     if config.model_type == "bloom":

@@ -48,24 +57,19 @@ def get_model(
             return BLOOMSharded(model_id, revision, quantize=quantize)
         else:
             return BLOOM(model_id, revision, quantize=quantize)
-    elif config.model_type == "gpt_neox":
+
+    if config.model_type == "gpt_neox":
         if sharded:
             return GPTNeoxSharded(model_id, revision, quantize=quantize)
         else:
             return GPTNeox(model_id, revision, quantize=quantize)
-    elif config.model_type == "t5":
+
+    if config.model_type == "t5":
         if sharded:
             return T5Sharded(model_id, revision, quantize=quantize)
         else:
             return Seq2SeqLM(model_id, revision, quantize=quantize)
-    elif model_id.startswith("facebook/galactica"):
-        if sharded:
-            return GalacticaSharded(model_id, revision, quantize=quantize)
-        else:
-            return Galactica(model_id, revision, quantize=quantize)
-    elif "santacoder" in model_id:
-        return SantaCoder(model_id, revision, quantize)
-    else:
-        if sharded:
-            raise ValueError("sharded is not supported for AutoModel")
-        try:
+
+    if sharded:
+        raise ValueError("sharded is not supported for AutoModel")
+    try:
server/text_generation/models/bloom.py

@@ -23,7 +23,6 @@ from text_generation.pb import generate_pb2
 from text_generation.utils import (
     initialize_torch_distributed,
     weight_files,
-    download_weights,
 )
 
 HAS_BITS_AND_BYTES = True

@@ -80,14 +79,8 @@ class BLOOMSharded(BLOOM):
         )
         config.pad_token_id = 3
 
-        # Only download weights for small models
-        if self.master and model_id == "bigscience/bloom-560m":
-            download_weights(model_id, revision=revision, extension=".safetensors")
-
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        if not filenames:
-            raise ValueError("No safetensors weights found")
 
         with init_empty_weights():
             model = AutoModelForCausalLM.from_config(config)
server/text_generation/models/galactica.py

@@ -26,7 +26,6 @@ from text_generation.utils import (
     StoppingCriteria,
     initialize_torch_distributed,
     weight_files,
-    download_weights,
 )
 
 HAS_BITS_AND_BYTES = True

@@ -172,14 +171,8 @@ class GalacticaSharded(Galactica):
         )
         tokenizer.pad_token_id = config.pad_token_id
 
-        # Only download weights for small models
-        if self.master and model_id == "facebook/galactica-125m":
-            download_weights(model_id, revision=revision, extension=".safetensors")
-
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        if not filenames:
-            raise ValueError("No safetensors weights found")
 
         with init_empty_weights():
             model = AutoModelForCausalLM.from_config(config)
server/text_generation/models/gpt_neox.py

@@ -20,7 +20,6 @@ from text_generation.models import CausalLM
 from text_generation.utils import (
     initialize_torch_distributed,
     weight_files,
-    download_weights,
 )
 
 HAS_BITS_AND_BYTES = True

@@ -69,14 +68,8 @@ class GPTNeoxSharded(GPTNeox):
             model_id, revision=revision, tp_parallel=True
         )
 
-        # Only master download weights
-        if self.master:
-            download_weights(model_id, revision=revision, extension=".safetensors")
-
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        if not filenames:
-            raise ValueError("No safetensors weights found")
 
         with init_empty_weights():
             model = AutoModelForCausalLM.from_config(config)
server/text_generation/models/santacoder.py

 import torch
 import torch.distributed
 
-from typing import Optional, List, Tuple
+from typing import Optional, List
 
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
 from text_generation.models import CausalLM
server/text_generation/models/t5.py

@@ -20,7 +20,6 @@ from text_generation.models import Seq2SeqLM
 from text_generation.utils import (
     initialize_torch_distributed,
     weight_files,
-    download_weights,
 )
 
 HAS_BITS_AND_BYTES = True

@@ -53,14 +52,8 @@ class T5Sharded(Seq2SeqLM):
         )
         tokenizer.bos_token_id = config.decoder_start_token_id
 
-        # Only master download weights
-        if self.master:
-            download_weights(model_id, revision=revision, extension=".safetensors")
-
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        if not filenames:
-            raise ValueError("No safetensors weights found")
 
         with init_empty_weights():
             model = AutoModelForSeq2SeqLM.from_config(config)
server/text_generation/utils/__init__.py (new file, mode 100644)

from text_generation.utils.convert import convert_file, convert_files
from text_generation.utils.dist import initialize_torch_distributed
from text_generation.utils.hub import (
    weight_files,
    weight_hub_files,
    download_weights,
    EntryNotFoundError,
    LocalEntryNotFoundError,
    RevisionNotFoundError,
)
from text_generation.utils.tokens import (
    Greedy,
    NextTokenChooser,
    Sampling,
    StoppingCriteria,
    StopSequenceCriteria,
    FinishReason,
)

__all__ = [
    "convert_file",
    "convert_files",
    "initialize_torch_distributed",
    "weight_files",
    "weight_hub_files",
    "download_weights",
    "EntryNotFoundError",
    "LocalEntryNotFoundError",
    "RevisionNotFoundError",
    "Greedy",
    "NextTokenChooser",
    "Sampling",
    "StoppingCriteria",
    "StopSequenceCriteria",
    "FinishReason",
]
server/text_generation/utils/convert.py (new file, mode 100644)

import concurrent
import time
import torch

from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict
from datetime import timedelta
from loguru import logger
from pathlib import Path
from safetensors.torch import load_file, save_file
from typing import Dict, List


def check_file_size(source_file: Path, target_file: Path):
    """
    Check that two files are close in size
    """
    source_file_size = source_file.stat().st_size
    target_file_size = target_file.stat().st_size

    if (source_file_size - target_file_size) / source_file_size > 0.01:
        raise RuntimeError(
            f"""The file size different is more than 1%:
         - {source_file}: {source_file_size}
         - {target_file}: {target_file_size}
         """
        )


def remove_shared_pointers(tensors: Dict[str, torch.Tensor]):
    """
    For a Dict of tensors, check if two or more tensors point to the same underlying memory and
    remove them
    """
    ptrs = defaultdict(list)
    for k, v in tensors.items():
        ptrs[v.data_ptr()].append(k)

    # Iterate over all found memory addresses
    for ptr, names in ptrs.items():
        if len(names) > 1:
            # Multiple tensors are point to the same memory
            # Only keep the first tensor
            for name in names[1:]:
                tensors.pop(name)


def convert_file(pt_file: Path, st_file: Path):
    """
    Convert a pytorch file to a safetensors file
    """
    pt_state = torch.load(pt_file, map_location="cpu")
    if "state_dict" in pt_state:
        pt_state = pt_state["state_dict"]

    remove_shared_pointers(pt_state)

    # Tensors need to be contiguous
    pt_state = {k: v.contiguous() for k, v in pt_state.items()}

    st_file.parent.mkdir(parents=True, exist_ok=True)
    save_file(pt_state, str(st_file), metadata={"format": "pt"})

    # Check that both files are close in size
    check_file_size(pt_file, st_file)

    # Load safetensors state
    st_state = load_file(str(st_file))
    for k in st_state:
        pt_tensor = pt_state[k]
        st_tensor = st_state[k]
        if not torch.equal(pt_tensor, st_tensor):
            raise RuntimeError(f"The output tensors do not match for key {k}")


def convert_files(pt_files: List[Path], st_files: List[Path]):
    assert len(pt_files) == len(st_files)

    executor = ThreadPoolExecutor(max_workers=5)
    futures = [
        executor.submit(convert_file, pt_file=pt_file, st_file=st_file)
        for pt_file, st_file in zip(pt_files, st_files)
    ]

    # We do this instead of using tqdm because we want to parse the logs with the launcher
    logger.info("Converting weights...")
    start_time = time.time()
    for i, future in enumerate(concurrent.futures.as_completed(futures)):
        elapsed = timedelta(seconds=int(time.time() - start_time))
        remaining = len(futures) - (i + 1)
        if remaining != 0:
            eta = (elapsed / (i + 1)) * remaining
        else:
            eta = 0

        logger.info(f"Convert: [{i + 1}/{len(futures)}] -- ETA: {eta}")
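As a quick usage sketch, `convert_file` can also be called on a single checkpoint; the paths below are placeholders, not files added by this commit:

```python
from pathlib import Path
from text_generation.utils.convert import convert_file

# Placeholder paths for illustration.
pt_file = Path("pytorch_model.bin")
st_file = Path("model.safetensors")

# Loads the checkpoint on CPU, drops tensors that alias the same storage,
# saves the safetensors file, then checks file sizes and tensor equality.
convert_file(pt_file, st_file)
```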
server/text_generation/utils/dist.py (new file, mode 100644)

import os
import torch

from datetime import timedelta


def initialize_torch_distributed():
    rank = int(os.getenv("RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))

    if torch.cuda.is_available():
        from torch.distributed import ProcessGroupNCCL

        # Set the device id.
        assert world_size <= torch.cuda.device_count(), "Each process is one gpu"
        device = rank % torch.cuda.device_count()
        torch.cuda.set_device(device)
        backend = "nccl"
        options = ProcessGroupNCCL.Options()
        options.is_high_priority_stream = True
        options._timeout = timedelta(seconds=60)
    else:
        backend = "gloo"
        options = None

    # Call the init process.
    torch.distributed.init_process_group(
        backend=backend,
        world_size=world_size,
        rank=rank,
        timeout=timedelta(seconds=60),
        pg_options=options,
    )

    return torch.distributed.group.WORLD, rank, world_size
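A minimal single-process sketch of how this helper is meant to be called (RANK and WORLD_SIZE default to 0 and 1 when unset; without a GPU the gloo backend is used):

```python
import os
from text_generation.utils.dist import initialize_torch_distributed

# Single-process defaults; in a sharded launch each shard sets its own values.
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")

process_group, rank, world_size = initialize_torch_distributed()
print(rank, world_size)
```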
server/text_generation/utils/hub.py (new file, mode 100644)

import time
import concurrent
import os

from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from loguru import logger
from pathlib import Path
from typing import Optional, List

from huggingface_hub import HfApi, hf_hub_download
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from huggingface_hub.utils import (
    LocalEntryNotFoundError,
    EntryNotFoundError,
    RevisionNotFoundError,  # Import here to ease try/except in other part of the lib
)

WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)


def weight_hub_files(
    model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
) -> List[str]:
    """Get the weights filenames on the hub"""
    api = HfApi()
    info = api.model_info(model_id, revision=revision)
    filenames = [s.rfilename for s in info.siblings if s.rfilename.endswith(extension)]

    if not filenames:
        raise EntryNotFoundError(
            f"No {extension} weights found for model {model_id} and revision {revision}.",
            None,
        )

    return filenames


def try_to_load_from_cache(
    model_id: str, revision: Optional[str], filename: str
) -> Optional[Path]:
    """Try to load a file from the Hugging Face cache"""
    if revision is None:
        revision = "main"

    object_id = model_id.replace("/", "--")
    repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{object_id}"

    if not repo_cache.is_dir():
        # No cache for this model
        return None

    refs_dir = repo_cache / "refs"
    snapshots_dir = repo_cache / "snapshots"
    no_exist_dir = repo_cache / ".no_exist"

    # Resolve refs (for instance to convert main to the associated commit sha)
    if refs_dir.is_dir():
        revision_file = refs_dir / revision
        if revision_file.exists():
            with revision_file.open() as f:
                revision = f.read()

    # Check if file is cached as "no_exist"
    if (no_exist_dir / revision / filename).is_file():
        return None

    # Check if revision folder exists
    if not snapshots_dir.exists():
        return None
    cached_shas = os.listdir(snapshots_dir)
    if revision not in cached_shas:
        # No cache for this revision and we won't try to return a random revision
        return None

    # Check if file exists in cache
    cached_file = snapshots_dir / revision / filename
    return cached_file if cached_file.is_file() else None


def weight_files(
    model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
) -> List[Path]:
    """Get the local files"""
    try:
        filenames = weight_hub_files(model_id, revision, extension)
    except EntryNotFoundError as e:
        if extension != ".safetensors":
            raise e
        # Try to see if there are pytorch weights
        pt_filenames = weight_hub_files(model_id, revision, extension=".bin")
        # Change pytorch extension to safetensors extension
        # It is possible that we have safetensors weights locally even though they are not on the
        # hub if we converted weights locally without pushing them
        filenames = [
            f"{Path(f).stem.lstrip('pytorch_')}.safetensors" for f in pt_filenames
        ]

    if WEIGHTS_CACHE_OVERRIDE is not None:
        files = []
        for filename in filenames:
            p = Path(WEIGHTS_CACHE_OVERRIDE) / filename
            if not p.exists():
                raise LocalEntryNotFoundError(
                    f"File {p} not found in {WEIGHTS_CACHE_OVERRIDE}."
                )
            files.append(p)
        return files

    files = []
    for filename in filenames:
        cache_file = try_to_load_from_cache(
            model_id, revision=revision, filename=filename
        )
        if cache_file is None:
            raise LocalEntryNotFoundError(
                f"File {filename} of model {model_id} not found in "
                f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
                f"Please run `text-generation-server download-weights {model_id}` first."
            )
        files.append(cache_file)

    return files


def download_weights(
    filenames: List[str], model_id: str, revision: Optional[str] = None
) -> List[Path]:
    """Download the safetensors files from the hub"""

    def download_file(filename):
        local_file = try_to_load_from_cache(model_id, revision, filename)

        if local_file is not None:
            logger.info(f"File {filename} already present in cache.")
            return local_file

        start_time = time.time()
        local_file = hf_hub_download(
            filename=filename,
            repo_id=model_id,
            revision=revision,
            local_files_only=False,
        )
        logger.info(
            f"Downloaded {filename} at {local_file} in "
            f"{timedelta(seconds=int(time.time() - start_time))}."
        )
        return local_file

    executor = ThreadPoolExecutor(max_workers=5)
    futures = [
        executor.submit(download_file, filename=filename) for filename in filenames
    ]

    # We do this instead of using tqdm because we want to parse the logs with the launcher
    logger.info("Downloading weights...")
    start_time = time.time()
    files = []
    for i, future in enumerate(concurrent.futures.as_completed(futures)):
        elapsed = timedelta(seconds=int(time.time() - start_time))
        remaining = len(futures) - (i + 1)
        if remaining != 0:
            eta = (elapsed / (i + 1)) * remaining
        else:
            eta = 0

        logger.info(f"Download: [{i + 1}/{len(futures)}] -- ETA: {eta}")
        files.append(Path(future.result()))

    return [Path(p) for p in files]
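Putting the three helpers together, a typical flow mirroring `test_download_weights` looks like this (example model id):

```python
from text_generation.utils.hub import weight_hub_files, download_weights, weight_files

model_id = "bigscience/bloom-560m"  # example model id

# List the .safetensors files on the hub, download them into the local
# Hugging Face cache, then resolve their cached paths.
filenames = weight_hub_files(model_id)
downloaded = download_weights(filenames, model_id)
local_files = weight_files(model_id)
assert downloaded == local_files
```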
server/text_generation/utils.py → server/text_generation/utils/tokens.py

-import concurrent
-import os
 import re
 import torch
-import torch.distributed
 
-from datetime import timedelta
-from concurrent.futures import ThreadPoolExecutor
-from functools import partial
-from pathlib import Path
-
-from huggingface_hub import HfApi, hf_hub_download, _CACHED_NO_EXIST
-from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
-from huggingface_hub.utils import LocalEntryNotFoundError
-from tqdm import tqdm
-from typing import List, Optional, Tuple
-
-from transformers import PreTrainedTokenizerBase
-from transformers.generation.logits_process import (
+from transformers import (
     LogitsProcessorList,
     RepetitionPenaltyLogitsProcessor,
     TemperatureLogitsWarper,
-    TopPLogitsWarper,
     TopKLogitsWarper,
+    TopPLogitsWarper,
+    PreTrainedTokenizerBase,
 )
+from typing import List, Tuple, Optional
 
 from text_generation.pb import generate_pb2
 from text_generation.pb.generate_pb2 import FinishReason
 
-WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
-
 
 class Sampling:
     def __init__(self, seed: int, device: str = "cpu"):

@@ -154,130 +140,3 @@ class StoppingCriteria:
         return StoppingCriteria(
             tokenizer.eos_token_id, stop_sequence_criterias, pb.max_new_tokens
         )
-
-
-def initialize_torch_distributed():
-    rank = int(os.getenv("RANK", "0"))
-    world_size = int(os.getenv("WORLD_SIZE", "1"))
-
-    if torch.cuda.is_available():
-        from torch.distributed import ProcessGroupNCCL
-
-        # Set the device id.
-        assert world_size <= torch.cuda.device_count(), "Each process is one gpu"
-        device = rank % torch.cuda.device_count()
-        torch.cuda.set_device(device)
-        backend = "nccl"
-        options = ProcessGroupNCCL.Options()
-        options.is_high_priority_stream = True
-        options._timeout = timedelta(seconds=60)
-    else:
-        backend = "gloo"
-        options = None
-
-    # Call the init process.
-    torch.distributed.init_process_group(
-        backend=backend,
-        world_size=world_size,
-        rank=rank,
-        timeout=timedelta(seconds=60),
-        pg_options=options,
-    )
-
-    return torch.distributed.group.WORLD, rank, world_size
-
-
-def weight_hub_files(model_id, revision=None, extension=".safetensors"):
-    """Get the safetensors filenames on the hub"""
-    api = HfApi()
-    info = api.model_info(model_id, revision=revision)
-    filenames = [s.rfilename for s in info.siblings if s.rfilename.endswith(extension)]
-    return filenames
-
-
-def try_to_load_from_cache(model_id, revision, filename):
-    """Try to load a file from the Hugging Face cache"""
-    if revision is None:
-        revision = "main"
-
-    object_id = model_id.replace("/", "--")
-    repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{object_id}"
-
-    if not repo_cache.is_dir():
-        # No cache for this model
-        return None
-
-    refs_dir = repo_cache / "refs"
-    snapshots_dir = repo_cache / "snapshots"
-    no_exist_dir = repo_cache / ".no_exist"
-
-    # Resolve refs (for instance to convert main to the associated commit sha)
-    if refs_dir.is_dir():
-        revision_file = refs_dir / revision
-        if revision_file.exists():
-            with revision_file.open() as f:
-                revision = f.read()
-
-    # Check if file is cached as "no_exist"
-    if (no_exist_dir / revision / filename).is_file():
-        return _CACHED_NO_EXIST
-
-    # Check if revision folder exists
-    if not snapshots_dir.exists():
-        return None
-    cached_shas = os.listdir(snapshots_dir)
-    if revision not in cached_shas:
-        # No cache for this revision and we won't try to return a random revision
-        return None
-
-    # Check if file exists in cache
-    cached_file = snapshots_dir / revision / filename
-    return str(cached_file) if cached_file.is_file() else None
-
-
-def weight_files(model_id, revision=None, extension=".safetensors"):
-    """Get the local safetensors filenames"""
-    if WEIGHTS_CACHE_OVERRIDE is not None:
-        return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))
-
-    filenames = weight_hub_files(model_id, revision, extension)
-    files = []
-    for filename in filenames:
-        cache_file = try_to_load_from_cache(
-            model_id, revision=revision, filename=filename
-        )
-        if cache_file is None:
-            raise LocalEntryNotFoundError(
-                f"File {filename} of model {model_id} not found in "
-                f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
-                f"Please run `text-generation-server download-weights {model_id}` first."
-            )
-        files.append(cache_file)
-
-    return files
-
-
-def download_weights(model_id, revision=None, extension=".safetensors"):
-    """Download the safetensors files from the hub"""
-    if WEIGHTS_CACHE_OVERRIDE is not None:
-        return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))
-
-    filenames = weight_hub_files(model_id, revision, extension)
-
-    download_function = partial(
-        hf_hub_download,
-        repo_id=model_id,
-        local_files_only=False,
-    )
-
-    executor = ThreadPoolExecutor(max_workers=5)
-    futures = [
-        executor.submit(download_function, filename=filename, revision=revision)
-        for filename in filenames
-    ]
-    files = [
-        future.result()
-        for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures))
-    ]
-
-    return files