OpenDAS / text-generation-inference · Commits

Commit f830706b (unverified)
Authored Jan 31, 2023 by OlivierDehaene; committed by GitHub on Jan 31, 2023

feat(server): Support GPT-Neox (#39)

Parent: c6e8b944

Showing 13 changed files with 386 additions and 47 deletions (+386 / -47).
README.md                                     +1   -0
launcher/src/main.rs                          +11  -0
server/tests/test_utils.py                    +5   -1
server/text_generation/cli.py                 +5   -2
server/text_generation/models/__init__.py     +22  -9
server/text_generation/models/bloom.py        +11  -5
server/text_generation/models/causal_lm.py    +5   -2
server/text_generation/models/galactica.py    +13  -12
server/text_generation/models/gpt_neox.py     +244 -0  (new file)
server/text_generation/models/santacoder.py   +5   -2
server/text_generation/models/seq2seq_lm.py   +5   -2
server/text_generation/server.py              +5   -3
server/text_generation/utils.py               +54  -9
README.md (+1, -0)

```diff
@@ -26,6 +26,7 @@ to power Bloom, BloomZ and MT0-XXL api-inference widgets.
 - [MT0-XXL](https://huggingface.co/bigscience/mt0-xxl)
 - ~~[Galactica](https://huggingface.co/facebook/galactica-120b)~~ (deactivated)
 - [SantaCoder](https://huggingface.co/bigcode/santacoder)
+- [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b): use `--revision refs/pr/13`

 Other models are supported on a best effort basis using:
```
launcher/src/main.rs (+11, -0)

```diff
@@ -21,6 +21,8 @@ struct Args {
     #[clap(default_value = "bigscience/bloom-560m", long, env)]
     model_name: String,
+    #[clap(long, env)]
+    revision: Option<String>,
     #[clap(long, env)]
     num_shard: Option<usize>,
     #[clap(long, env)]
     quantize: bool,
@@ -48,6 +50,7 @@ fn main() -> ExitCode {
     // Pattern match configuration
     let Args {
         model_name,
+        revision,
         num_shard,
         quantize,
         max_concurrent_requests,
@@ -90,6 +93,7 @@ fn main() -> ExitCode {
     // Start shard processes
     for rank in 0..num_shard {
         let model_name = model_name.clone();
+        let revision = revision.clone();
         let uds_path = shard_uds_path.clone();
         let master_addr = master_addr.clone();
         let status_sender = status_sender.clone();
@@ -98,6 +102,7 @@ fn main() -> ExitCode {
         thread::spawn(move || {
             shard_manager(
                 model_name,
+                revision,
                 quantize,
                 uds_path,
                 rank,
@@ -252,6 +257,7 @@ enum ShardStatus {
 #[allow(clippy::too_many_arguments)]
 fn shard_manager(
     model_name: String,
+    revision: Option<String>,
     quantize: bool,
     uds_path: String,
     rank: usize,
@@ -288,6 +294,11 @@ fn shard_manager(
         shard_argv.push("--quantize".to_string())
     }

+    if let Some(revision) = revision {
+        shard_argv.push("--revision".to_string());
+        shard_argv.push(revision)
+    }
+
     let mut env = vec![
         ("RANK".into(), rank.to_string().into()),
         ("WORLD_SIZE".into(), world_size.to_string().into()),
```
server/tests/test_utils.py (+5, -1)

```diff
 import pytest

+from huggingface_hub.utils import RevisionNotFoundError
+
 from text_generation.utils import (
     weight_hub_files,
     download_weights,
@@ -51,7 +53,7 @@ def test_weight_hub_files_llm():
 def test_weight_hub_files_empty():
-    filenames = weight_hub_files("bigscience/bloom", ".errors")
+    filenames = weight_hub_files("bigscience/bloom", extension=".errors")
     assert filenames == []
@@ -62,5 +64,7 @@ def test_download_weights():
 def test_weight_files_error():
+    with pytest.raises(RevisionNotFoundError):
+        weight_files("bigscience/bloom-560m", revision="error")
     with pytest.raises(LocalEntryNotFoundError):
         weight_files("bert-base-uncased")
```
server/text_generation/cli.py (+5, -2)

```diff
@@ -4,6 +4,7 @@ import typer
 from pathlib import Path
 from loguru import logger
+from typing import Optional

 from text_generation import server, utils
@@ -13,6 +14,7 @@ app = typer.Typer()
 @app.command()
 def serve(
     model_name: str,
+    revision: Optional[str] = None,
     sharded: bool = False,
     quantize: bool = False,
     uds_path: Path = "/tmp/text-generation",
@@ -44,15 +46,16 @@ def serve(
             os.getenv("MASTER_PORT", None) is not None
         ), "MASTER_PORT must be set when sharded is True"

-    server.serve(model_name, sharded, quantize, uds_path)
+    server.serve(model_name, revision, sharded, quantize, uds_path)


 @app.command()
 def download_weights(
     model_name: str,
+    revision: Optional[str] = None,
     extension: str = ".safetensors",
 ):
-    utils.download_weights(model_name, extension)
+    utils.download_weights(model_name, revision, extension)


 if __name__ == "__main__":
```
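For context, the launcher change above forwards `--revision` to this CLI when it spawns each shard. Below is a minimal sketch of exercising the new option through the typer app; it is not part of the commit, and it assumes typer's default dash-separated command naming and network access to the Hub (running it really downloads files):

```python
from typer.testing import CliRunner

from text_generation.cli import app

# Invoke the download_weights command with the new --revision option, the same
# way `text-generation-server download-weights ... --revision ...` would be
# called from a shell by the launcher.
runner = CliRunner()
result = runner.invoke(
    app,
    ["download-weights", "bigscience/bloom-560m", "--revision", "main"],
)
print(result.exit_code)  # 0 on success
```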
server/text_generation/models/__init__.py (+22, -9)

```diff
 import torch

+from transformers import AutoConfig
+from typing import Optional
+
 from text_generation.models.model import Model
 from text_generation.models.causal_lm import CausalLM
 from text_generation.models.bloom import BLOOM, BLOOMSharded
 from text_generation.models.seq2seq_lm import Seq2SeqLM
 from text_generation.models.galactica import Galactica, GalacticaSharded
 from text_generation.models.santacoder import SantaCoder
+from text_generation.models.gpt_neox import GPTNeox, GPTNeoxSharded

 __all__ = [
     "Model",
@@ -25,23 +29,32 @@ torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.allow_tf32 = True


-def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
-    if model_name.startswith("bigscience/bloom"):
+def get_model(
+    model_name: str, revision: Optional[str], sharded: bool, quantize: bool
+) -> Model:
+    config = AutoConfig.from_pretrained(model_name)
+
+    if config.model_type == "bloom":
         if sharded:
-            return BLOOMSharded(model_name, quantize=quantize)
+            return BLOOMSharded(model_name, revision, quantize=quantize)
         else:
-            return BLOOM(model_name, quantize=quantize)
+            return BLOOM(model_name, revision, quantize=quantize)
+    elif config.model_type == "gpt_neox":
+        if sharded:
+            return GPTNeoxSharded(model_name, revision, quantize=quantize)
+        else:
+            return GPTNeox(model_name, revision, quantize=quantize)
     elif model_name.startswith("facebook/galactica"):
         if sharded:
-            return GalacticaSharded(model_name, quantize=quantize)
+            return GalacticaSharded(model_name, revision, quantize=quantize)
         else:
-            return Galactica(model_name, quantize=quantize)
+            return Galactica(model_name, revision, quantize=quantize)
     elif "santacoder" in model_name:
-        return SantaCoder(model_name, quantize)
+        return SantaCoder(model_name, revision, quantize)
     else:
         if sharded:
             raise ValueError("sharded is not supported for AutoModel")
         try:
-            return CausalLM(model_name, quantize=quantize)
+            return CausalLM(model_name, revision, quantize=quantize)
         except Exception:
-            return Seq2SeqLM(model_name, quantize=quantize)
+            return Seq2SeqLM(model_name, revision, quantize=quantize)
```
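The dispatch above is what makes the README entry work: `get_model` now selects GPT-Neox by `config.model_type` and forwards the revision to the model constructor. A minimal usage sketch, not part of the commit (the model name and revision are the ones recommended in the README; actually loading the 20B checkpoint requires the corresponding hardware):

```python
from text_generation.models import get_model

# Dispatches on config.model_type ("gpt_neox") and threads the revision through,
# so tokenizer, config, and weights all come from the same ref.
model = get_model(
    "EleutherAI/gpt-neox-20b",
    revision="refs/pr/13",
    sharded=False,
    quantize=False,
)
print(type(model).__name__)  # GPTNeox
```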
server/text_generation/models/bloom.py (+11, -5)

```diff
@@ -56,7 +56,9 @@ class BLOOM(CausalLM):
 class BLOOMSharded(BLOOM):
-    def __init__(self, model_name: str, quantize: bool = False):
+    def __init__(
+        self, model_name: str, revision: Optional[str] = None, quantize: bool = False
+    ):
         if not model_name.startswith("bigscience/bloom"):
             raise ValueError(f"Model {model_name} is not supported")
@@ -69,19 +71,23 @@ class BLOOMSharded(BLOOM):
             device = torch.device("cpu")
             dtype = torch.float32

-        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_name, revision=revision, padding_side="left"
+        )

         config = AutoConfig.from_pretrained(
-            model_name, slow_but_exact=False, tp_parallel=True
+            model_name, revision=revision, slow_but_exact=False, tp_parallel=True
         )
         config.pad_token_id = 3

         # Only download weights for small models
         if self.master and model_name == "bigscience/bloom-560m":
-            download_weights(model_name, extension=".safetensors")
+            download_weights(model_name, revision=revision, extension=".safetensors")

         torch.distributed.barrier(group=self.process_group)
-        filenames = weight_files(model_name, extension=".safetensors")
+        filenames = weight_files(
+            model_name, revision=revision, extension=".safetensors"
+        )
         if not filenames:
             raise ValueError("No safetensors weights found")
```
server/text_generation/models/causal_lm.py (+5, -2)

```diff
@@ -232,7 +232,7 @@ class CausalLMBatch(Batch):
 class CausalLM(Model):
-    def __init__(self, model_name: str, quantize=False):
+    def __init__(self, model_name: str, revision: Optional[str] = None, quantize=False):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@@ -243,9 +243,12 @@ class CausalLM(Model):
             device = torch.device("cpu")
             dtype = torch.float32

-        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_name, revision=revision, padding_side="left"
+        )
         self.model = AutoModelForCausalLM.from_pretrained(
             model_name,
+            revision=revision,
             torch_dtype=dtype,
             device_map="auto" if torch.cuda.is_available() else None,
             load_in_8bit=quantize,
```
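The same two-line pattern recurs in the remaining model classes below (galactica.py, santacoder.py, seq2seq_lm.py): the new `revision` argument is passed to both the tokenizer and the weights so they are guaranteed to come from the same ref. A stand-alone sketch of that pattern, not part of the commit, using the launcher's default `bigscience/bloom-560m` model for illustration:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "bigscience/bloom-560m"  # launcher default, small enough to test with
revision = "main"                     # any branch, tag, or commit sha on the Hub

# Passing the same revision to both from_pretrained calls keeps the tokenizer and
# the weights in sync, which is exactly what this commit does in each model class.
tokenizer = AutoTokenizer.from_pretrained(
    model_name, revision=revision, padding_side="left"
)
model = AutoModelForCausalLM.from_pretrained(model_name, revision=revision)
```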
server/text_generation/models/galactica.py (+13, -12)

```diff
@@ -148,7 +148,9 @@ class Galactica(CausalLM):
 class GalacticaSharded(Galactica):
-    def __init__(self, model_name: str, quantize: bool = False):
+    def __init__(
+        self, model_name: str, revision: Optional[str] = None, quantize: bool = False
+    ):
         if not model_name.startswith("facebook/galactica"):
             raise ValueError(f"Model {model_name} is not supported")
@@ -161,24 +163,23 @@ class GalacticaSharded(Galactica):
             device = torch.device("cpu")
             dtype = torch.float32

-        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_name, revision=revision, padding_side="left"
+        )

-        config = AutoConfig.from_pretrained(model_name, tp_parallel=True)
+        config = AutoConfig.from_pretrained(
+            model_name, revision=revision, tp_parallel=True
+        )
         tokenizer.pad_token_id = config.pad_token_id

-        # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
-        # in PyTorch 1.12 and later.
-        torch.backends.cuda.matmul.allow_tf32 = True
-
-        # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
-        torch.backends.cudnn.allow_tf32 = True
-
         # Only download weights for small models
         if self.master and model_name == "facebook/galactica-125m":
-            download_weights(model_name, extension=".safetensors")
+            download_weights(model_name, revision=revision, extension=".safetensors")

         torch.distributed.barrier(group=self.process_group)
-        filenames = weight_files(model_name, extension=".safetensors")
+        filenames = weight_files(
+            model_name, revision=revision, extension=".safetensors"
+        )
         if not filenames:
             raise ValueError("No safetensors weights found")
```
server/text_generation/models/gpt_neox.py (new file, mode 100644, +244)

```python
import torch
import torch.distributed

from typing import List, Optional, Tuple

from accelerate import init_empty_weights
from safetensors import safe_open
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoConfig,
)
from transformers.models.gpt_neox.parallel_layers import (
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    TensorParallelRowLinear,
)

from text_generation.models import CausalLM
from text_generation.utils import (
    initialize_torch_distributed,
    weight_files,
    download_weights,
)

HAS_BITS_AND_BYTES = True
try:
    import bitsandbytes as bnb
    from bitsandbytes.nn import Int8Params
except Exception as e:
    HAS_BITS_AND_BYTES = False


class GPTNeox(CausalLM):
    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        """Overwrite forward to ignore position_ids"""

        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return outputs.logits, outputs.past_key_values


class GPTNeoxSharded(GPTNeox):
    def __init__(
        self, model_name: str, revision: Optional[str] = None, quantize: bool = False
    ):
        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
        self.master = self.rank == 0
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{self.rank}")
            dtype = torch.bfloat16
        else:
            device = torch.device("cpu")
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
            model_name, revision=revision, padding_side="left"
        )
        tokenizer.pad_token = tokenizer.eos_token

        config = AutoConfig.from_pretrained(
            model_name, revision=revision, tp_parallel=True
        )

        # Only master download weights
        if self.master:
            download_weights(model_name, revision=revision, extension=".safetensors")

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(
            model_name, revision=revision, extension=".safetensors"
        )
        if not filenames:
            raise ValueError("No safetensors weights found")

        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config)

        torch.distributed.barrier(group=self.process_group)
        self.load_weights(
            model,
            filenames,
            quantize=quantize,
            device=device,
            rank=self.rank,
            world_size=self.world_size,
        )
        self.model = model.eval().to(dtype)
        torch.distributed.barrier(group=self.process_group)
        super(CausalLM, self).__init__(
            tokenizer=tokenizer,
            device=device,
        )

    @staticmethod
    def load_weights(
        model,
        filenames: List[str],
        quantize: bool,
        device: torch.device,
        rank: int,
        world_size: int,
    ):
        parameters = dict(model.named_parameters())
        for file in filenames:
            with safe_open(
                file, framework="pt", device=str(device) if not quantize else "cpu"
            ) as f:
                for name in f.keys():
                    module_name, param_name = name.rsplit(".", 1)
                    module = model.get_submodule(module_name)

                    current_parameter_tensor = parameters.get(name, None)

                    slice_ = f.get_slice(name)

                    if isinstance(module, TensorParallelColumnLinear):
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
                    elif isinstance(module, TensorParallelRowLinear):
                        if param_name == "weight":
                            size = slice_.get_shape()[1]
                            block_size = size // world_size
                            start = rank * block_size
                            stop = (rank + 1) * block_size
                            tensor = slice_[:, start:stop]
                        else:
                            tensor = slice_[:]
                            # XXX: Hack for Rowlinear to add the bias only once.
                            if rank != 0:
                                tensor = torch.zeros_like(tensor)
                    elif isinstance(module, TensorParallelEmbedding):
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
                    elif name == "embed_out.weight":
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
                    else:
                        try:
                            tensor = slice_[:]
                        except:
                            tensor = f.get_tensor(name)

                    if (
                        current_parameter_tensor is not None
                        and current_parameter_tensor.shape != tensor.shape
                    ):
                        raise ValueError(
                            f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                        )

                    tensor = tensor.contiguous()

                    if quantize:
                        if not HAS_BITS_AND_BYTES:
                            raise ImportError(
                                "bitsandbytes is not available on your machine either because it is not installed "
                                "or you don't have a GPU.\n"
                                "You can install it with `pip install bitsandbytes`."
                            )

                        if (
                            type(module)
                            in [TensorParallelRowLinear, TensorParallelColumnLinear]
                            and param_name == "weight"
                        ):
                            tensor = Int8Params(
                                tensor,
                                has_fp16_weights=False,
                                requires_grad=False,
                            ).to(device)
                            state = bnb.MatmulLtState()
                            state.threshold = 6.0
                            state.has_fp16_weights = False
                            state.memory_efficient_backward = False
                            state.use_pool = True
                            state.CB = tensor.CB
                            state.SCB = tensor.SCB
                            tensor.CB = None
                            tensor.SCB = None

                            def replace_linear(state):
                                def linear(input, weight, bias):
                                    out = bnb.matmul(
                                        input,
                                        weight,
                                        state=state,
                                        threshold=state.threshold,
                                        bias=bias,
                                    )

                                    if state.CB is not None:
                                        # we converted 8-bit row major to turing/ampere format
                                        # in the first inference pass
                                        # we no longer need the row-major weight
                                        del state.CB
                                        weight.data = state.CxB

                                    return out

                                return linear

                            module.linear = replace_linear(state)
                        else:
                            tensor = tensor.to(device)

                    if current_parameter_tensor is not None:
                        module._parameters[param_name] = tensor
                    else:
                        module._buffers[param_name] = tensor

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ):
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )

        # Logits are sharded, so we need to gather them
        logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
        torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
        logits = torch.cat(logits, dim=2)

        return logits, outputs.past_key_values
```
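The core of `load_weights` above is the block slicing: a column-parallel layer keeps a contiguous slice of the output dimension on each rank, a row-parallel weight keeps a slice of the input dimension, and concatenating the shards recovers the full matrix (the sharded forward then re-assembles the logits with `all_gather` + `cat`). Below is a small self-contained sketch of that arithmetic, not part of the commit; it uses plain tensors instead of safetensors slices and an invented shape for illustration:

```python
import torch

world_size = 4
# [out_features, in_features]; both dimensions divisible by world_size.
full_weight = torch.randn(8, 8)

def column_shard(weight: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # TensorParallelColumnLinear: split dim 0 (output features) into contiguous blocks.
    block_size = weight.shape[0] // world_size
    start, stop = rank * block_size, (rank + 1) * block_size
    return weight[start:stop]

def row_shard(weight: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # TensorParallelRowLinear: split dim 1 (input features) into contiguous blocks.
    block_size = weight.shape[1] // world_size
    start, stop = rank * block_size, (rank + 1) * block_size
    return weight[:, start:stop]

# Re-assembling the per-rank shards gives back the original weight in both cases.
assert torch.equal(
    torch.cat([column_shard(full_weight, r, world_size) for r in range(world_size)], dim=0),
    full_weight,
)
assert torch.equal(
    torch.cat([row_shard(full_weight, r, world_size) for r in range(world_size)], dim=1),
    full_weight,
)
```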
server/text_generation/models/santacoder.py (+5, -2)

```diff
@@ -14,7 +14,7 @@ EOD = "<|endoftext|>"
 class SantaCoder(CausalLM):
-    def __init__(self, model_name: str, quantize=False):
+    def __init__(self, model_name: str, revision: Optional[str] = None, quantize=False):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@@ -25,7 +25,9 @@ class SantaCoder(CausalLM):
             device = torch.device("cpu")
             dtype = torch.float32

-        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_name, revision=revision, padding_side="left"
+        )
         tokenizer.add_special_tokens(
             {
                 "additional_special_tokens": [
@@ -42,6 +44,7 @@ class SantaCoder(CausalLM):
         self.model = (
             AutoModelForCausalLM.from_pretrained(
                 model_name,
+                revision=revision,
                 torch_dtype=dtype,
                 load_in_8bit=quantize,
                 trust_remote_code=True,  # required
```
server/text_generation/models/seq2seq_lm.py (+5, -2)

```diff
@@ -289,7 +289,7 @@ class Seq2SeqLMBatch(Batch):
 class Seq2SeqLM(Model):
-    def __init__(self, model_name: str, quantize=False):
+    def __init__(self, model_name: str, revision: Optional[str] = None, quantize=False):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@@ -302,11 +302,14 @@ class Seq2SeqLM(Model):
         self.model = AutoModelForSeq2SeqLM.from_pretrained(
             model_name,
+            revision=revision,
             torch_dtype=dtype,
             device_map="auto" if torch.cuda.is_available() else None,
             load_in_8bit=quantize,
         ).eval()
-        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_name, revision=revision, padding_side="left"
+        )
         tokenizer.bos_token_id = self.model.config.decoder_start_token_id

         super(Seq2SeqLM, self).__init__(
```
server/text_generation/server.py (+5, -3)

```diff
@@ -6,7 +6,7 @@ from loguru import logger
 from grpc_reflection.v1alpha import reflection
 from pathlib import Path
-from typing import List
+from typing import List, Optional

 from text_generation.cache import Cache
 from text_generation.interceptor import ExceptionInterceptor
@@ -67,12 +67,14 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
 def serve(
     model_name: str,
+    revision: Optional[str],
     sharded: bool,
     quantize: bool,
     uds_path: Path,
 ):
     async def serve_inner(
         model_name: str,
+        revision: Optional[str],
         sharded: bool = False,
         quantize: bool = False,
     ):
@@ -87,7 +89,7 @@ def serve(
             local_url = unix_socket_template.format(uds_path, 0)
             server_urls = [local_url]

-        model = get_model(model_name, sharded, quantize)
+        model = get_model(model_name, revision, sharded, quantize)

         server = aio.server(interceptors=[ExceptionInterceptor()])
         generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
@@ -107,4 +109,4 @@ def serve(
         logger.info("Signal received. Shutting down")
         await server.stop(0)

-    asyncio.run(serve_inner(model_name, sharded, quantize))
+    asyncio.run(serve_inner(model_name, revision, sharded, quantize))
```
server/text_generation/utils.py (+54, -9)

```diff
@@ -8,7 +8,9 @@ from datetime import timedelta
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
-from huggingface_hub import HfApi, hf_hub_download, try_to_load_from_cache
+from pathlib import Path
+from huggingface_hub import HfApi, hf_hub_download, _CACHED_NO_EXIST
+from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from huggingface_hub.utils import LocalEntryNotFoundError
 from tqdm import tqdm
 from typing import List, Optional, Tuple
@@ -170,20 +172,62 @@ def initialize_torch_distributed():
     return torch.distributed.distributed_c10d._get_default_group(), rank, world_size


-def weight_hub_files(model_name, extension=".safetensors"):
+def weight_hub_files(model_name, revision=None, extension=".safetensors"):
     """Get the safetensors filenames on the hub"""
     api = HfApi()
-    info = api.model_info(model_name)
+    info = api.model_info(model_name, revision=revision)
     filenames = [s.rfilename for s in info.siblings if s.rfilename.endswith(extension)]
     return filenames


-def weight_files(model_name, extension=".safetensors"):
+def try_to_load_from_cache(model_name, revision, filename):
+    """Try to load a file from the Hugging Face cache"""
+    if revision is None:
+        revision = "main"
+
+    object_id = model_name.replace("/", "--")
+    repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{object_id}"
+
+    if not repo_cache.is_dir():
+        # No cache for this model
+        return None
+
+    refs_dir = repo_cache / "refs"
+    snapshots_dir = repo_cache / "snapshots"
+    no_exist_dir = repo_cache / ".no_exist"
+
+    # Resolve refs (for instance to convert main to the associated commit sha)
+    if refs_dir.is_dir():
+        revision_file = refs_dir / revision
+        if revision_file.exists():
+            with revision_file.open() as f:
+                revision = f.read()
+
+    # Check if file is cached as "no_exist"
+    if (no_exist_dir / revision / filename).is_file():
+        return _CACHED_NO_EXIST
+
+    # Check if revision folder exists
+    if not snapshots_dir.exists():
+        return None
+    cached_shas = os.listdir(snapshots_dir)
+    if revision not in cached_shas:
+        # No cache for this revision and we won't try to return a random revision
+        return None
+
+    # Check if file exists in cache
+    cached_file = snapshots_dir / revision / filename
+    return str(cached_file) if cached_file.is_file() else None
+
+
+def weight_files(model_name, revision=None, extension=".safetensors"):
     """Get the local safetensors filenames"""
-    filenames = weight_hub_files(model_name, extension)
+    filenames = weight_hub_files(model_name, revision, extension)
     files = []
     for filename in filenames:
-        cache_file = try_to_load_from_cache(model_name, filename=filename)
+        cache_file = try_to_load_from_cache(
+            model_name, revision=revision, filename=filename
+        )
         if cache_file is None:
             raise LocalEntryNotFoundError(
                 f"File {filename} of model {model_name} not found in "
@@ -195,9 +239,9 @@ def weight_files(model_name, extension=".safetensors"):
     return files


-def download_weights(model_name, extension=".safetensors"):
+def download_weights(model_name, revision=None, extension=".safetensors"):
     """Download the safetensors files from the hub"""
-    filenames = weight_hub_files(model_name, extension)
+    filenames = weight_hub_files(model_name, revision, extension)

     download_function = partial(
         hf_hub_download,
@@ -207,7 +251,8 @@ def download_weights(model_name, extension=".safetensors"):
     executor = ThreadPoolExecutor(max_workers=5)
     futures = [
-        executor.submit(download_function, filename=filename) for filename in filenames
+        executor.submit(download_function, filename=filename, revision=revision)
+        for filename in filenames
     ]
     files = [
         future.result()
```
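The local `try_to_load_from_cache` above re-implements the Hub cache lookup so it can honor an arbitrary revision. A sketch of the on-disk layout it walks, not part of the commit, with a hypothetical repo name for illustration:

```python
from pathlib import Path

from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE

# Layout inspected by try_to_load_from_cache (repo name below is illustrative):
#
#   models--EleutherAI--gpt-neox-20b/
#       refs/main               # text file containing the resolved commit sha
#       snapshots/<sha>/<file>  # the actual cached files
#       .no_exist/<sha>/<file>  # marker: the file is known not to exist on the Hub
repo_cache = Path(HUGGINGFACE_HUB_CACHE) / "models--EleutherAI--gpt-neox-20b"
ref_file = repo_cache / "refs" / "main"
sha = ref_file.read_text() if ref_file.exists() else None
print(repo_cache, sha)
```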