OpenDAS / text-generation-inference / Commits / c0aeb325

Unverified commit c0aeb325, authored Apr 03, 2023 by OlivierDehaene, committed via GitHub on Apr 03, 2023.

feat(server): flash santacoder (#153)

Parent: fef1a1c3

Showing 8 changed files with 980 additions and 455 deletions (+980 / -455):

  server/text_generation_server/models/__init__.py                                    +18    -8
  server/text_generation_server/models/custom_modeling/__init__.py                     +0    -0
  server/text_generation_server/models/custom_modeling/flash_neox_modeling.py          +0    -0
  server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py  +357    -0
  server/text_generation_server/models/flash_causal_lm.py                            +458    -0
  server/text_generation_server/models/flash_neox.py                                   +7  -445
  server/text_generation_server/models/flash_santacoder.py                           +138    -0
  supported_models.json                                                                 +2    -2
server/text_generation_server/models/__init__.py  (view file @ c0aeb325)

@@ -8,6 +8,7 @@ from typing import Optional
 from text_generation_server.models.model import Model
 from text_generation_server.models.causal_lm import CausalLM
+from text_generation_server.models.flash_causal_lm import FlashCausalLM
 from text_generation_server.models.bloom import BLOOM, BLOOMSharded
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
 from text_generation_server.models.galactica import Galactica, GalacticaSharded
@@ -17,18 +18,22 @@ from text_generation_server.models.t5 import T5Sharded
 try:
     from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
-    FLASH_NEOX = torch.cuda.is_available() and int(os.environ.get("FLASH_NEOX", 0)) == 1
+    from text_generation_server.models.flash_santacoder import FlashSantacoder
+
+    FLASH_ATTENTION = (
+        torch.cuda.is_available() and int(os.environ.get("FLASH_ATTENTION", 0)) == 1
+    )
 except ImportError:
-    if int(os.environ.get("FLASH_NEOX", 0)) == 1:
-        logger.exception("Could not import Flash NeoX")
-    FLASH_NEOX = False
+    if int(os.environ.get("FLASH_ATTENTION", 0)) == 1:
+        logger.exception("Could not import Flash Attention models")
+    FLASH_ATTENTION = False

 __all__ = [
     "Model",
     "BLOOM",
     "BLOOMSharded",
     "CausalLM",
+    "FlashCausalLM",
     "Galactica",
     "GalacticaSharded",
     "GPTNeoxSharded",
@@ -38,9 +43,10 @@ __all__ = [
     "get_model",
 ]

-if FLASH_NEOX:
+if FLASH_ATTENTION:
     __all__.append(FlashNeoX)
     __all__.append(FlashNeoXSharded)
+    __all__.append(FlashSantacoder)

 # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
 # in PyTorch 1.12 and later.
@@ -63,7 +69,11 @@ def get_model(
         return Galactica(model_id, revision, quantize=quantize)

     if "santacoder" in model_id:
-        return SantaCoder(model_id, revision, quantize)
+        if sharded:
+            raise NotImplementedError("sharded is not supported for Santacoder")
+        else:
+            santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
+            return santacoder_cls(model_id, revision, quantize)

     config = AutoConfig.from_pretrained(model_id, revision=revision)
     model_type = config.model_type
@@ -76,10 +86,10 @@ def get_model(
     if model_type == "gpt_neox":
         if sharded:
-            neox_cls = FlashNeoXSharded if FLASH_NEOX else GPTNeoxSharded
+            neox_cls = FlashNeoXSharded if FLASH_ATTENTION else GPTNeoxSharded
             return neox_cls(model_id, revision, quantize=quantize)
         else:
-            neox_cls = FlashNeoX if FLASH_NEOX else CausalLM
+            neox_cls = FlashNeoX if FLASH_ATTENTION else CausalLM
             return neox_cls(model_id, revision, quantize=quantize)

     if model_type == "t5":
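A minimal sketch (not part of the commit) of the opt-in gate introduced above: the fused-attention models are only used when a CUDA device is present and the user explicitly sets FLASH_ATTENTION=1. The helper name below is invented for illustration.

# --- illustrative example, not part of the commit ---
import os
import torch

def flash_attention_enabled() -> bool:
    # Requires a CUDA GPU and an explicit FLASH_ATTENTION=1 in the environment
    return torch.cuda.is_available() and int(os.environ.get("FLASH_ATTENTION", 0)) == 1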
server/text_generation_server/models/custom_modeling/__init__.py  (new empty file, 0 → 100644, view file @ c0aeb325)

server/text_generation_server/models/flash_neox_modeling.py → server/text_generation_server/models/custom_modeling/flash_neox_modeling.py  (file moved, view file @ c0aeb325)

server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py  (new file, 0 → 100644, view file @ c0aeb325)
import torch
import torch.distributed

from torch import nn
from transformers.activations import ACT2FN

# Flash attention imports
import flash_attn_cuda
import dropout_layer_norm


class FastLayerNorm(nn.LayerNorm):
    def forward(self, hidden_states, residual=None):
        if hidden_states.shape[-1] > 6144:
            if residual is not None:
                hidden_states += residual
            residual = hidden_states

            return super(FastLayerNorm, self).forward(hidden_states), residual
        else:
            (
                normed_hidden_states,
                residual,
                *rest,
            ) = dropout_layer_norm.dropout_add_ln_fwd(
                hidden_states,
                residual,
                self.weight,
                self.bias,
                None,
                None,
                None,
                None,
                0.0,
                self.eps,
                1.0,
                0,
                None,
                False,
                False,
            )
            if residual is None:
                residual = hidden_states

            return normed_hidden_states, residual


class FastLinear(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)

    def transpose_weight(self):
        self.weight = nn.Parameter(self.weight.T)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self.bias is not None:
            return torch.addmm(self.bias, input, self.weight)
        return torch.matmul(input, self.weight)
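A minimal sketch (not from the commit) of why transpose_weight is needed: nn.Linear stores its weight as (out_features, in_features), while torch.addmm(bias, input, weight) expects (in_features, out_features). After transposing, the fused addmm matches the standard linear layer output.

# --- illustrative example, not part of the commit ---
import torch

x = torch.randn(4, 8)
linear = torch.nn.Linear(8, 16)
expected = linear(x)                           # x @ W.T + b
transposed_weight = linear.weight.detach().T   # (8, 16): what transpose_weight() stores
fused = torch.addmm(linear.bias.detach(), x, transposed_weight)
assert torch.allclose(expected, fused, atol=1e-6)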
class FlashMQAttention(torch.nn.Module):
    def __init__(
        self,
        num_heads,
        hidden_size,
        process_group=None,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.head_size = hidden_size // num_heads

        self.softmax_scale = self.head_size ** (-0.5)

        if process_group is None:
            self.attn = FastLinear(hidden_size, hidden_size + 2 * self.head_size)
            self.c_proj = FastLinear(hidden_size, hidden_size)
        else:
            raise NotImplementedError

    def forward(
        self,
        hidden_states,
        cu_seqlens,
        max_s,
        layer_past,
        layer_past_present_indices,
        cu_seqlens_q,
    ):
        qkv = self.attn(hidden_states)

        # Split query from key_value
        query, key_value = qkv.split([self.hidden_size, 2 * self.head_size], dim=1)

        # Prepare query and key_value for indexing
        query = query.view(-1, self.num_heads, self.head_size)
        key_value = key_value.view(-1, 2, 1, self.head_size)

        # Prefill
        if layer_past_present_indices is None:
            # Copy to layer past
            layer_past[...] = key_value
            # Expand from 1 to num_heads
            key_value = key_value.expand(-1, 2, self.num_heads, self.head_size)

            # output
            attn_output = torch.empty_like(query)
            # flash attention
            flash_attn_cuda.fwd(
                query,
                key_value[:, 0],
                key_value[:, 1],
                attn_output,
                cu_seqlens,
                cu_seqlens,
                max_s,
                max_s,
                0.0,
                self.softmax_scale,
                False,
                True,
                False,
                0,
                None,
            )
        # Decode
        else:
            # Add present to the layer_past tensor at the correct indices
            layer_past[layer_past_present_indices] = key_value
            # Expand from 1 to num_heads
            key_value = layer_past.expand(-1, 2, self.num_heads, self.head_size)

            # output
            attn_output = torch.empty_like(query)
            # flash attention
            flash_attn_cuda.fwd(
                query,
                key_value[:, 0],
                key_value[:, 1],
                attn_output,
                cu_seqlens_q,
                cu_seqlens,
                1,
                max_s,
                0.0,
                self.softmax_scale,
                False,
                False,
                False,
                0,
                None,
            )

        return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size))
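A minimal sketch (not from the commit) of the multi-query attention trick used above: a single shared key/value head is stored, then broadcast across all query heads with .expand(), which creates a view instead of copying memory.

# --- illustrative example, not part of the commit ---
import torch

num_tokens, num_heads, head_size = 5, 8, 64
key_value = torch.randn(num_tokens, 2, 1, head_size)      # one shared K and V head
expanded = key_value.expand(-1, 2, num_heads, head_size)   # viewed as num_heads heads
print(expanded.shape)                                      # torch.Size([5, 2, 8, 64])
print(expanded.data_ptr() == key_value.data_ptr())         # True: no extra memory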
class MLP(nn.Module):
    def __init__(self, act, hidden_size, intermediate_size, process_group=None):
        super().__init__()
        self.act = (
            ACT2FN[act]
            if "gelu" not in act
            else lambda x: torch.nn.functional.gelu(
                x,
                approximate="tanh"
                if act in ["gelu_fast", "gelu_pytorch_tanh"]
                else None,
            )
        )

        if process_group is None:
            self.c_fc = FastLinear(hidden_size, intermediate_size)
            self.c_proj = FastLinear(intermediate_size, hidden_size)
        else:
            raise NotImplementedError

    def forward(self, hidden_states):
        hidden_states = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        return hidden_states


class Block(nn.Module):
    def __init__(
        self,
        num_heads,
        act,
        hidden_size,
        intermediate_size,
        layer_norm_eps,
        process_group=None,
    ):
        super().__init__()
        self.ln_1 = FastLayerNorm(hidden_size, eps=layer_norm_eps)
        self.ln_2 = FastLayerNorm(hidden_size, eps=layer_norm_eps)
        self.attn = FlashMQAttention(
            num_heads,
            hidden_size,
            process_group,
        )
        self.mlp = MLP(
            act,
            hidden_size,
            intermediate_size,
            process_group,
        )

    def forward(
        self,
        hidden_states,
        residual,
        cu_seqlens,
        max_s,
        layer_past,
        layer_past_present_indices,
        cu_seqlens_q,
    ):
        hidden_states, residual = self.ln_1(hidden_states, residual)

        hidden_states = self.attn(
            hidden_states,
            cu_seqlens,
            max_s,
            layer_past,
            layer_past_present_indices,
            cu_seqlens_q,
        )

        hidden_states, residual = self.ln_2(hidden_states, residual)

        mlp_output = self.mlp(hidden_states)

        return mlp_output, residual


class FlashSantacoderModel(nn.Module):
    def __init__(self, config, process_group=None):
        super().__init__()
        self.config = config

        if process_group is not None:
            raise NotImplementedError

        self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
        self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        self.h = nn.ModuleList(
            [
                Block(
                    config.num_attention_heads,
                    config.activation_function,
                    config.hidden_size,
                    config.n_inner
                    if config.n_inner is not None
                    else 4 * config.hidden_size,
                    config.layer_norm_epsilon,
                    process_group,
                )
                for _ in range(config.num_hidden_layers)
            ]
        )
        self.ln_f = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        self.head_size = self.h[0].attn.head_size
        self.num_heads = self.h[0].attn.num_heads

    def post_load_weights(self):
        for layer in self.h:
            layer: Block
            layer.attn.attn.transpose_weight()
            layer.attn.c_proj.transpose_weight()
            layer.mlp.c_fc.transpose_weight()
            layer.mlp.c_proj.transpose_weight()

    def forward(
        self,
        input_ids,
        position_ids,
        cu_seqlens,
        max_s,
        past_key_values=None,
    ):
        hidden_states = self.wte(input_ids) + self.wpe(position_ids)

        # Prefill
        if past_key_values is None:
            # Create past tensor
            past_key_values = hidden_states.new_empty(
                (
                    len(self.h),
                    len(hidden_states),
                    2,
                    1,
                    self.head_size,
                )
            )
            layer_past_present_indices = None
            cu_seqlens_q = None
        # Decode
        else:
            # Create indices from cumulative sequence lengths
            layer_past_present_indices = cu_seqlens[1:] - 1
            cu_seqlens_q = torch.arange(
                cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device
            )

        residual = None
        for i, layer in enumerate(self.h):
            hidden_states, residual = layer(
                hidden_states,
                residual,
                cu_seqlens,
                max_s,
                past_key_values[i],
                layer_past_present_indices,
                cu_seqlens_q,
            )

        hidden_states, _ = self.ln_f(hidden_states, residual)

        return hidden_states, past_key_values


class FlashSantacoderForCausalLM(nn.Module):
    def __init__(self, config, process_group=None):
        super().__init__()

        self.transformer = FlashSantacoderModel(config, process_group)

        self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)

    def post_load_weights(self):
        self.transformer.post_load_weights()
        self.lm_head.transpose_weight()

    def forward(
        self,
        input_ids,
        position_ids,
        cu_seqlens,
        max_s,
        past_key_values=None,
    ):
        hidden_states, present = self.transformer(
            input_ids, position_ids, cu_seqlens, max_s, past_key_values
        )
        return self.lm_head(hidden_states), present
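A minimal sketch (not from the commit) of the decode-step indexing used in FlashSantacoderModel.forward: during decode each sequence contributes exactly one new token, so cu_seqlens[1:] - 1 gives the flat index of the last position of every sequence in the packed batch, which is where the new key/value pair is written into layer_past.

# --- illustrative example, not part of the commit ---
import torch

# three packed sequences of lengths 3, 5 and 2
cu_seqlens = torch.tensor([0, 3, 8, 10], dtype=torch.int32)
layer_past_present_indices = cu_seqlens[1:] - 1
print(layer_past_present_indices)   # tensor([2, 7, 9], dtype=torch.int32)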
server/text_generation_server/models/flash_causal_lm.py  (new file, 0 → 100644, view file @ c0aeb325)

import torch
import torch.distributed

from torch.nn import functional as F

from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel
from typing import Optional, Tuple, List, Type, Union

from text_generation_server.models import Model
from text_generation_server.models.types import (
    Batch,
    PrefillTokens,
    Generation,
    GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import (
    NextTokenChooser,
    StoppingCriteria,
    Sampling,
)

tracer = trace.get_tracer(__name__)


@dataclass
class FlashCausalLMBatch(Batch):
    batch_id: int
    requests: List[generate_pb2.Request]

    # Decoder values
    input_ids: torch.Tensor
    position_ids: torch.Tensor
    # cumulative sequence lengths
    cu_seqlens: torch.Tensor
    max_seqlen: int
    past_key_values: Optional[torch.Tensor]

    # All tokens
    all_input_ids: List[List[int]]
    all_input_ids_tensor: List[torch.Tensor]

    # Lengths of all generations present in the batch
    input_lengths: List[int]

    # Generation helpers
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    def to_pb(self) -> generate_pb2.Batch:
        return generate_pb2.Batch(
            id=self.batch_id, requests=self.requests, size=len(self)
        )

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        device: torch.device,
    ) -> "CausalLMBatch":
        input_ids = []
        position_ids = []
        cu_seqlens = [0]
        max_seqlen = 0

        input_lengths = []
        all_input_ids = []
        all_input_ids_tensor = []

        next_token_choosers = []
        stopping_criterias = []

        # Cumulative length
        cumulative_length = 0

        # Parse batch
        for r in pb.requests:
            tokenized_input = tokenizer(r.inputs)["input_ids"]
            input_length = len(tokenized_input)
            max_seqlen = max(max_seqlen, input_length)
            input_lengths.append(input_length)
            all_input_ids.append(tokenized_input)

            tokenized_input = torch.tensor(tokenized_input, device=device)
            input_ids.append(tokenized_input)

            # Position ids
            position_ids.append(torch.arange(0, input_length, dtype=torch.int32))

            # Add cumulative lengths of all previous inputs
            cu_seqlens.append(cumulative_length + input_length)

            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
            stopping_criteria = StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
            stopping_criterias.append(stopping_criteria)

            all_input_ids_tensor.append(
                F.pad(tokenized_input, (0, stopping_criteria.max_new_tokens))
            )

            # Update
            cumulative_length += input_length

        input_ids = torch.concat(input_ids)
        position_ids = torch.concat(position_ids)
        cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            past_key_values=None,
            input_lengths=input_lengths,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
        )
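A minimal sketch (not from the commit) of the packing scheme built in from_pb: variable-length sequences are concatenated into one flat token dimension with no padding, and the cumulative sequence lengths are what let the attention kernel tell sequences apart.

# --- illustrative example, not part of the commit ---
import torch

sequences = [[101, 7592, 2088], [101, 2023], [101, 2003, 1037, 3231]]
input_ids = torch.tensor([tok for seq in sequences for tok in seq])   # shape (9,), no padding
cu_seqlens = torch.tensor([0, 3, 5, 9], dtype=torch.int32)
# Tokens of sequence i live at input_ids[cu_seqlens[i]:cu_seqlens[i + 1]]
assert input_ids[cu_seqlens[1]:cu_seqlens[2]].tolist() == sequences[1]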
    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
        # Batch attributes
        requests = []
        input_lengths = []
        all_input_ids = []
        all_input_ids_tensor = []
        next_token_choosers = []
        stopping_criterias = []

        # Batch tensors
        input_ids = []
        position_ids = []
        cu_seqlens = [torch.tensor([0], dtype=torch.int32)]
        max_seqlen = 0
        past_key_values = []

        # Cumulative length
        cumulative_length = torch.tensor(0)

        for i, batch in enumerate(batches):
            requests.extend(batch.requests)
            input_lengths.extend(batch.input_lengths)
            all_input_ids.extend(batch.all_input_ids)
            all_input_ids_tensor.extend(batch.all_input_ids_tensor)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            # Add cumulative lengths of all previous inputs
            cu_seqlens.append(batch.cu_seqlens[1:] + cumulative_length)

            input_ids.append(batch.input_ids)
            position_ids.append(batch.position_ids)
            past_key_values.append(batch.past_key_values)

            max_seqlen = max(max_seqlen, batch.max_seqlen)

            # Update
            cumulative_length += batch.cu_seqlens[-1]

        input_ids = torch.concat(input_ids)
        position_ids = torch.concat(position_ids)
        # Concat on dim=1 as first dim represents the model layers
        past_key_values = torch.concat(past_key_values, dim=1)
        cu_seqlens = torch.concat(cu_seqlens)

        return FlashCausalLMBatch(
            batch_id=batches[0].batch_id,
            requests=requests,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
        )

    def __len__(self):
        return len(self.requests)
class FlashCausalLM(Model):
    def __init__(
        self,
        model_cls: Type[PreTrainedModel],
        model_id: str,
        revision: Optional[str] = None,
        quantize=False,
    ):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        else:
            raise NotImplementedError("FlashCausalLM is only available on GPU")

        if quantize:
            raise NotImplementedError("FlashCausalLM does not support quantization")

        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left"
        )
        self.model = (
            model_cls.from_pretrained(
                model_id,
                revision=revision,
                torch_dtype=dtype,
            )
            .eval()
            .cuda()
        )

        super(FlashCausalLM, self).__init__(
            tokenizer=tokenizer,
            device=device,
        )

    @property
    def batch_type(self) -> Type[FlashCausalLMBatch]:
        return FlashCausalLMBatch

    def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=True, cleanup_tokenization_spaces=False
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_s: int,
        past_key_values: Optional = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Model Forward
        return self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            max_s=max_s,
            past_key_values=past_key_values,
        )

    @tracer.start_as_current_span("generate_token")
    def generate_token(
        self, batch: FlashCausalLMBatch
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
        # Better to send to device here to avoid device issues in concatenate
        position_ids = batch.position_ids.to(self.device, non_blocking=True)
        cu_seqlens = batch.cu_seqlens.to(self.device)

        out, present = self.forward(
            batch.input_ids,
            position_ids,
            cu_seqlens,
            batch.max_seqlen,
            batch.past_key_values,
        )

        # List of indices to cache
        next_batch_keep_indices = []

        # New values for next forward
        next_batch_input_ids = []
        next_batch_position_ids = []
        next_batch_cu_seqlens = [0]
        next_batch_max_seqlen = 0
        next_batch_past_key_values = []
        next_batch_input_lengths = []
        next_batch_all_input_ids = []
        next_batch_all_input_ids_tensor = []

        # Cumulative length
        cumulative_length = 0

        # Results
        generations: List[Generation] = []

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.all_input_ids_tensor,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            next_token_chooser,
            stopping_criteria,
            all_input_ids,
            all_input_ids_tensor,
        ) in enumerate(iterator):
            # Indexing metadata
            start_index = cumulative_length
            end_index = cumulative_length + input_length

            if batch.past_key_values is None:
                # Prefill mode
                # out is of shape [cumulative_sequence_lengths, vocab_size]
                logits = out[start_index:end_index]
            else:
                # Decode mode
                # out is of shape [batch_size, vocab_size]
                logits = out[i].unsqueeze(0)

            # Select next token
            next_token_id, logprobs = next_token_chooser(
                all_input_ids_tensor[None, :input_length], logits
            )
            next_token_id_squeezed = next_token_id.squeeze()
            next_token_id_item = next_token_id_squeezed.item()

            # Append next token to all tokens
            all_input_ids.append(next_token_id_item)
            all_input_ids_tensor[input_length] = next_token_id_item
            new_input_length = input_length + 1

            # Generated token
            next_token_logprob = logprobs[-1, next_token_id_item]
            next_token_text = self.decode_token(
                next_token_id_item,
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id_item,
                next_token_text,
            )

            if stop:
                # Decode generated tokens
                output_text = self.decode(
                    all_input_ids[-stopping_criteria.current_tokens :]
                )
                # Get seed
                if isinstance(next_token_chooser.choice, Sampling):
                    seed = next_token_chooser.choice.seed
                else:
                    seed = None

                generated_text = GeneratedText(
                    output_text, stopping_criteria.current_tokens, reason, seed
                )
            else:
                # Keep request in the batch
                next_batch_keep_indices.append(i)
                generated_text = None

                # Get sequence present
                seq_present = present[:, start_index:end_index]
                # Pad it for next iter attention
                past = torch.nn.functional.pad(seq_present, (0, 0, 0, 0, 0, 0, 0, 1))
                next_batch_past_key_values.append(past)

                next_batch_input_ids.append(next_token_id)
                next_batch_position_ids.append(input_length)
                # Cumulative sum
                next_batch_cu_seqlens.append(
                    next_batch_cu_seqlens[-1] + new_input_length
                )
                next_batch_input_lengths.append(new_input_length)
                next_batch_all_input_ids.append(all_input_ids)
                next_batch_all_input_ids_tensor.append(all_input_ids_tensor)
                next_batch_max_seqlen = max(next_batch_max_seqlen, new_input_length)

            # Prefill
            if stopping_criteria.current_tokens == 1:
                # Remove generated token to only have prefill and add nan for first prompt token
                prefill_logprobs = [float("nan")] + logprobs.gather(
                    1, all_input_ids_tensor[1:input_length].unsqueeze(1)
                ).squeeze(1)[:-1].tolist()
                prefill_token_ids = all_input_ids[:-1]
                prefill_texts = self.tokenizer.batch_decode(
                    prefill_token_ids,
                    clean_up_tokenization_spaces=False,
                    skip_special_tokens=False,
                )
                prefill_tokens = PrefillTokens(
                    prefill_token_ids, prefill_logprobs, prefill_texts
                )
            else:
                prefill_tokens = None

            generation = Generation(
                request.id,
                prefill_tokens,
                next_token_id_item,
                next_token_logprob,
                next_token_text,
                next_token_id_item in self.all_special_ids,
                generated_text,
            )

            generations.append(generation)
            cumulative_length += input_length

        # We finished all generations in the batch; there is no next batch
        if not next_batch_keep_indices:
            return generations, None

        # If we finished at least one generation, we need to evict the indices of the generations that finished
        # from the values of the next batch
        if len(next_batch_keep_indices) != len(batch):
            # Apply indices to requests, token_choosers and stopping_criterias that need to be cached
            next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
            next_batch_next_token_choosers = [
                batch.next_token_choosers[i] for i in next_batch_keep_indices
            ]
            next_batch_stopping_criterias = [
                batch.stopping_criterias[i] for i in next_batch_keep_indices
            ]
        else:
            next_batch_requests = batch.requests
            next_batch_next_token_choosers = batch.next_token_choosers
            next_batch_stopping_criterias = batch.stopping_criterias

        # Create final next batch tensors
        next_batch_position_ids = torch.tensor(
            next_batch_position_ids, dtype=torch.int32
        )
        next_batch_cu_seqlens = torch.tensor(next_batch_cu_seqlens, dtype=torch.int32)
        if len(next_batch_keep_indices) > 1:
            next_batch_input_ids = torch.concat(next_batch_input_ids).squeeze(1)
            next_batch_past_key_values = torch.concat(next_batch_past_key_values, dim=1)
        else:
            next_batch_input_ids = next_batch_input_ids[0].view(1)
            next_batch_past_key_values = next_batch_past_key_values[0]

        next_batch = FlashCausalLMBatch(
            batch_id=batch.batch_id,
            requests=next_batch_requests,
            input_ids=next_batch_input_ids,
            position_ids=next_batch_position_ids,
            cu_seqlens=next_batch_cu_seqlens,
            max_seqlen=next_batch_max_seqlen,
            past_key_values=next_batch_past_key_values,
            input_lengths=next_batch_input_lengths,
            all_input_ids=next_batch_all_input_ids,
            all_input_ids_tensor=next_batch_all_input_ids_tensor,
            next_token_choosers=next_batch_next_token_choosers,
            stopping_criterias=next_batch_stopping_criterias,
        )

        return generations, next_batch
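A minimal sketch (not from the commit) of the padding call used on seq_present above: F.pad pads dimensions from the last one backwards in (left, right) pairs, so (0, 0, 0, 0, 0, 0, 0, 1) grows the per-sequence key/value cache by one slot along the sequence dimension, leaving room for the token produced at the next decode step. The shape below is an assumed example layout (layers, seq_len, k/v, kv_heads, head_size).

# --- illustrative example, not part of the commit ---
import torch

seq_present = torch.randn(24, 7, 2, 1, 64)
past = torch.nn.functional.pad(seq_present, (0, 0, 0, 0, 0, 0, 0, 1))
print(past.shape)   # torch.Size([24, 8, 2, 1, 64]): one extra slot on the sequence dim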
server/text_generation_server/models/flash_neox.py  (view file @ c0aeb325)

 import torch
 import torch.distributed

-from torch.nn import functional as F
 from accelerate import init_empty_weights
-from dataclasses import dataclass
 from opentelemetry import trace
 from safetensors import safe_open
-from transformers import AutoTokenizer, PreTrainedTokenizerBase, AutoConfig
-from typing import Optional, Tuple, List, Type, Union
+from transformers import AutoTokenizer, AutoConfig
+from typing import Optional, Tuple, List

-from text_generation_server.models import Model
-from text_generation_server.models.flash_neox_modeling import (
+from text_generation_server.models import FlashCausalLM
+from text_generation_server.models.custom_modeling.flash_neox_modeling import (
     FlashGPTNeoXForCausalLM,
     TensorParallelEmbedding,
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
 )
-from text_generation_server.models.types import (
-    Batch,
-    PrefillTokens,
-    Generation,
-    GeneratedText,
-)
-from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import (
-    NextTokenChooser,
-    StoppingCriteria,
-    Sampling,
     initialize_torch_distributed,
     weight_files,
 )

@@ -35,437 +22,12 @@ from text_generation_server.utils import (

 tracer = trace.get_tracer(__name__)


-@dataclass
-class FlashNeoXBatch(Batch):
-    batch_id: int
-    requests: List[generate_pb2.Request]
-
-    # Decoder values
-    input_ids: torch.Tensor
-    position_ids: torch.Tensor
-    # cumulative sequence lengths
-    cu_seqlens: torch.Tensor
-    max_seqlen: int
-    past_key_values: Optional[torch.Tensor]
-
-    # All tokens
-    all_input_ids: List[List[int]]
-    all_input_ids_tensor: List[torch.Tensor]
-
-    # Lengths of all generations present in the batch
-    input_lengths: List[int]
-
-    # Generation helpers
-    next_token_choosers: List[NextTokenChooser]
-    stopping_criterias: List[StoppingCriteria]
-
-    def to_pb(self) -> generate_pb2.Batch:
-        return generate_pb2.Batch(
-            id=self.batch_id, requests=self.requests, size=len(self)
-        )
-
-    @classmethod
-    def from_pb(
-        cls,
-        pb: generate_pb2.Batch,
-        tokenizer: PreTrainedTokenizerBase,
-        device: torch.device,
-    ) -> "CausalLMBatch":
-        input_ids = []
-        position_ids = []
-        cu_seqlens = [0]
-        max_seqlen = 0
-
-        input_lengths = []
-        all_input_ids = []
-        all_input_ids_tensor = []
-
-        next_token_choosers = []
-        stopping_criterias = []
-
-        # Cumulative length
-        cumulative_length = 0
-
-        # Parse batch
-        for r in pb.requests:
-            tokenized_input = tokenizer(r.inputs)["input_ids"]
-            input_length = len(tokenized_input)
-            max_seqlen = max(max_seqlen, input_length)
-            input_lengths.append(input_length)
-            all_input_ids.append(tokenized_input)
-
-            tokenized_input = torch.tensor(tokenized_input, device=device)
-            input_ids.append(tokenized_input)
-
-            # Position ids
-            position_ids.append(torch.arange(0, input_length, dtype=torch.int32))
-
-            # Add cumulative lengths of all previous inputs
-            cu_seqlens.append(cumulative_length + input_length)
-
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
-            stopping_criteria = StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
-            stopping_criterias.append(stopping_criteria)
-
-            all_input_ids_tensor.append(
-                F.pad(tokenized_input, (0, stopping_criteria.max_new_tokens))
-            )
-
-            # Update
-            cumulative_length += input_length
-
-        input_ids = torch.concat(input_ids)
-        position_ids = torch.concat(position_ids)
-        cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32)
-
-        return cls(
-            batch_id=pb.id,
-            requests=pb.requests,
-            input_ids=input_ids,
-            position_ids=position_ids,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=max_seqlen,
-            past_key_values=None,
-            input_lengths=input_lengths,
-            all_input_ids=all_input_ids,
-            all_input_ids_tensor=all_input_ids_tensor,
-            next_token_choosers=next_token_choosers,
-            stopping_criterias=stopping_criterias,
-        )
-
-    @classmethod
-    @tracer.start_as_current_span("concatenate")
-    def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
-        # Batch attributes
-        requests = []
-        input_lengths = []
-        all_input_ids = []
-        all_input_ids_tensor = []
-        next_token_choosers = []
-        stopping_criterias = []
-
-        # Batch tensors
-        input_ids = []
-        position_ids = []
-        cu_seqlens = [torch.tensor([0], dtype=torch.int32)]
-        max_seqlen = 0
-        past_key_values = []
-
-        # Cumulative length
-        cumulative_length = torch.tensor(0)
-
-        for i, batch in enumerate(batches):
-            requests.extend(batch.requests)
-            input_lengths.extend(batch.input_lengths)
-            all_input_ids.extend(batch.all_input_ids)
-            all_input_ids_tensor.extend(batch.all_input_ids_tensor)
-            next_token_choosers.extend(batch.next_token_choosers)
-            stopping_criterias.extend(batch.stopping_criterias)
-
-            # Add cumulative lengths of all previous inputs
-            cu_seqlens.append(batch.cu_seqlens[1:] + cumulative_length)
-
-            input_ids.append(batch.input_ids)
-            position_ids.append(batch.position_ids)
-            past_key_values.append(batch.past_key_values)
-
-            max_seqlen = max(max_seqlen, batch.max_seqlen)
-
-            # Update
-            cumulative_length += batch.cu_seqlens[-1]
-
-        input_ids = torch.concat(input_ids)
-        position_ids = torch.concat(position_ids)
-        # Concat on dim=1 as first dim represents the model layers
-        past_key_values = torch.concat(past_key_values, dim=1)
-        cu_seqlens = torch.concat(cu_seqlens)
-
-        return FlashNeoXBatch(
-            batch_id=batches[0].batch_id,
-            requests=requests,
-            input_ids=input_ids,
-            position_ids=position_ids,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=max_seqlen,
-            past_key_values=past_key_values,
-            input_lengths=input_lengths,
-            all_input_ids=all_input_ids,
-            all_input_ids_tensor=all_input_ids_tensor,
-            next_token_choosers=next_token_choosers,
-            stopping_criterias=stopping_criterias,
-        )
-
-    def __len__(self):
-        return len(self.requests)
-
-
-class FlashNeoX(Model):
-    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
-        if torch.cuda.is_available():
-            device = torch.device("cuda")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
-        else:
-            raise NotImplementedError("FlashNeoX is only available on GPU")
-
-        if quantize:
-            raise NotImplementedError("FlashNeoX does not support quantization")
-
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
-        )
-        self.model = (
-            FlashGPTNeoXForCausalLM.from_pretrained(
-                model_id,
-                revision=revision,
-                torch_dtype=dtype,
-            )
-            .eval()
-            .cuda()
-        )
-        tokenizer.pad_token_id = (
-            self.model.config.pad_token_id
-            if self.model.config.pad_token_id is not None
-            else self.model.config.eos_token_id
-        )
-
-        super(FlashNeoX, self).__init__(
-            tokenizer=tokenizer,
-            device=device,
-        )
-
-    @property
-    def batch_type(self) -> Type[FlashNeoXBatch]:
-        return FlashNeoXBatch
-
-    def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
-        return self.tokenizer.decode(
-            generated_ids, skip_special_tokens=True, cleanup_tokenization_spaces=False
-        )
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        position_ids: torch.Tensor,
-        cu_seqlens: torch.Tensor,
-        max_s: int,
-        past_key_values: Optional = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # Model Forward
-        return self.model.forward(
-            input_ids=input_ids,
-            position_ids=position_ids,
-            cu_seqlens=cu_seqlens,
-            max_s=max_s,
-            past_key_values=past_key_values,
-        )
-
-    @tracer.start_as_current_span("generate_token")
-    def generate_token(
-        self, batch: FlashNeoXBatch
-    ) -> Tuple[List[Generation], Optional[FlashNeoXBatch]]:
-        # Better to send to device here to avoid device issues in concatenate
-        position_ids = batch.position_ids.to(self.device, non_blocking=True)
-        cu_seqlens = batch.cu_seqlens.to(self.device)
-
-        out, present = self.forward(
-            batch.input_ids,
-            position_ids,
-            cu_seqlens,
-            batch.max_seqlen,
-            batch.past_key_values,
-        )
-
-        # List of indices to cache
-        next_batch_keep_indices = []
-
-        # New values for next forward
-        next_batch_input_ids = []
-        next_batch_position_ids = []
-        next_batch_cu_seqlens = [0]
-        next_batch_max_seqlen = 0
-        next_batch_past_key_values = []
-        next_batch_input_lengths = []
-        next_batch_all_input_ids = []
-        next_batch_all_input_ids_tensor = []
-
-        # Cumulative length
-        cumulative_length = 0
-
-        # Results
-        generations: List[Generation] = []
-
-        # Zipped iterator
-        iterator = zip(
-            batch.requests,
-            batch.input_lengths,
-            batch.next_token_choosers,
-            batch.stopping_criterias,
-            batch.all_input_ids,
-            batch.all_input_ids_tensor,
-        )
-
-        # For each member of the batch
-        for i, (
-            request,
-            input_length,
-            next_token_chooser,
-            stopping_criteria,
-            all_input_ids,
-            all_input_ids_tensor,
-        ) in enumerate(iterator):
-            # Indexing metadata
-            start_index = cumulative_length
-            end_index = cumulative_length + input_length
-
-            if batch.past_key_values is None:
-                # Prefill mode
-                # out is of shape [cumulative_sequence_lengths, vocab_size]
-                logits = out[start_index:end_index]
-            else:
-                # Decode mode
-                # out is of shape [batch_size, vocab_size]
-                logits = out[i].unsqueeze(0)
-
-            # Select next token
-            next_token_id, logprobs = next_token_chooser(
-                all_input_ids_tensor[None, :input_length], logits
-            )
-            next_token_id_squeezed = next_token_id.squeeze()
-            next_token_id_item = next_token_id_squeezed.item()
-
-            # Append next token to all tokens
-            all_input_ids.append(next_token_id_item)
-            all_input_ids_tensor[input_length] = next_token_id_item
-            new_input_length = input_length + 1
-
-            # Generated token
-            next_token_logprob = logprobs[-1, next_token_id_item]
-            next_token_text = self.decode_token(
-                next_token_id_item,
-            )
-
-            # Evaluate stopping criteria
-            stop, reason = stopping_criteria(
-                next_token_id_item,
-                next_token_text,
-            )
-
-            if stop:
-                # Decode generated tokens
-                output_text = self.decode(
-                    all_input_ids[-stopping_criteria.current_tokens :]
-                )
-                # Get seed
-                if isinstance(next_token_chooser.choice, Sampling):
-                    seed = next_token_chooser.choice.seed
-                else:
-                    seed = None
-
-                generated_text = GeneratedText(
-                    output_text, stopping_criteria.current_tokens, reason, seed
-                )
-            else:
-                # Keep request in the batch
-                next_batch_keep_indices.append(i)
-                generated_text = None
-
-                # Get sequence present
-                seq_present = present[:, start_index:end_index]
-                # Pad it for next iter attention
-                past = torch.nn.functional.pad(seq_present, (0, 0, 0, 0, 0, 0, 0, 1))
-                next_batch_past_key_values.append(past)
-
-                next_batch_input_ids.append(next_token_id)
-                next_batch_position_ids.append(input_length)
-                # Cumulative sum
-                next_batch_cu_seqlens.append(
-                    next_batch_cu_seqlens[-1] + new_input_length
-                )
-                next_batch_input_lengths.append(new_input_length)
-                next_batch_all_input_ids.append(all_input_ids)
-                next_batch_all_input_ids_tensor.append(all_input_ids_tensor)
-                next_batch_max_seqlen = max(next_batch_max_seqlen, new_input_length)
-
-            # Prefill
-            if stopping_criteria.current_tokens == 1:
-                # Remove generated token to only have prefill and add nan for first prompt token
-                prefill_logprobs = [float("nan")] + logprobs.gather(
-                    1, all_input_ids_tensor[1:input_length].unsqueeze(1)
-                ).squeeze(1)[:-1].tolist()
-                prefill_token_ids = all_input_ids[:-1]
-                prefill_texts = self.tokenizer.batch_decode(
-                    prefill_token_ids,
-                    clean_up_tokenization_spaces=False,
-                    skip_special_tokens=False,
-                )
-                prefill_tokens = PrefillTokens(
-                    prefill_token_ids, prefill_logprobs, prefill_texts
-                )
-            else:
-                prefill_tokens = None
-
-            generation = Generation(
-                request.id,
-                prefill_tokens,
-                next_token_id_item,
-                next_token_logprob,
-                next_token_text,
-                next_token_id_item in self.all_special_ids,
-                generated_text,
-            )
-
-            generations.append(generation)
-            cumulative_length += input_length
-
-        # We finished all generations in the batch; there is no next batch
-        if not next_batch_keep_indices:
-            return generations, None
-
-        # If we finished at least one generation, we need to evict the indices of the generations that finished
-        # from the values of the next batch
-        if len(next_batch_keep_indices) != len(batch):
-            # Apply indices to requests, token_choosers and stopping_criterias that need to be cached
-            next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
-            next_batch_next_token_choosers = [
-                batch.next_token_choosers[i] for i in next_batch_keep_indices
-            ]
-            next_batch_stopping_criterias = [
-                batch.stopping_criterias[i] for i in next_batch_keep_indices
-            ]
-        else:
-            next_batch_requests = batch.requests
-            next_batch_next_token_choosers = batch.next_token_choosers
-            next_batch_stopping_criterias = batch.stopping_criterias
-
-        # Create final next batch tensors
-        next_batch_position_ids = torch.tensor(
-            next_batch_position_ids, dtype=torch.int32
-        )
-        next_batch_cu_seqlens = torch.tensor(next_batch_cu_seqlens, dtype=torch.int32)
-        if len(next_batch_keep_indices) > 1:
-            next_batch_input_ids = torch.concat(next_batch_input_ids).squeeze(1)
-            next_batch_past_key_values = torch.concat(next_batch_past_key_values, dim=1)
-        else:
-            next_batch_input_ids = next_batch_input_ids[0].view(1)
-            next_batch_past_key_values = next_batch_past_key_values[0]
-
-        next_batch = FlashNeoXBatch(
-            batch_id=batch.batch_id,
-            requests=next_batch_requests,
-            input_ids=next_batch_input_ids,
-            position_ids=next_batch_position_ids,
-            cu_seqlens=next_batch_cu_seqlens,
-            max_seqlen=next_batch_max_seqlen,
-            past_key_values=next_batch_past_key_values,
-            input_lengths=next_batch_input_lengths,
-            all_input_ids=next_batch_all_input_ids,
-            all_input_ids_tensor=next_batch_all_input_ids_tensor,
-            next_token_choosers=next_batch_next_token_choosers,
-            stopping_criterias=next_batch_stopping_criterias,
-        )
-
-        return generations, next_batch
+class FlashNeoX(FlashCausalLM):
+    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
+        super(FlashNeoX, self).__init__(
+            FlashGPTNeoXForCausalLM, model_id, revision, quantize
+        )


 class FlashNeoXSharded(FlashNeoX):
     def __init__(
@@ -508,7 +70,7 @@ class FlashNeoXSharded(FlashNeoX):
         model.post_load_weights()
         self.model = model.eval().to(dtype)
         torch.distributed.barrier(group=self.process_group)
-        super(FlashNeoX, self).__init__(
+        super(FlashCausalLM, self).__init__(
             tokenizer=tokenizer,
             device=device,
         )
server/text_generation_server/models/flash_santacoder.py  (new file, 0 → 100644, view file @ c0aeb325)

import torch
import torch.distributed

from accelerate import init_empty_weights
from opentelemetry import trace
from pathlib import Path
from transformers import AutoTokenizer, AutoConfig
from typing import Optional, List

from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
    FlashSantacoderForCausalLM,
)
from text_generation_server.utils import (
    weight_files,
    download_weights,
    weight_hub_files,
    LocalEntryNotFoundError,
)

tracer = trace.get_tracer(__name__)


class FlashSantacoder(FlashCausalLM):
    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        else:
            raise NotImplementedError("FlashSantacoder is only available on GPU")

        if quantize:
            raise NotImplementedError("FlashSantacoder does not support quantization")

        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left"
        )

        config = AutoConfig.from_pretrained(
            model_id, revision=revision,
            trust_remote_code=True  # Needed as the config is not part of Transformers
        )

        # We do not use from_pretrained as we modified the model internal module layout
        try:
            filenames = weight_files(model_id, revision, ".bin")
        # Local files not found
        except LocalEntryNotFoundError:
            hub_files = weight_hub_files(model_id, revision, ".bin")
            filenames = download_weights(hub_files, model_id, revision)

        with init_empty_weights():
            model = FlashSantacoderForCausalLM(config)

        self.load_weights(
            model,
            filenames,
        )
        self.model = model.eval().to(device).to(dtype)

        super(FlashCausalLM, self).__init__(
            tokenizer=tokenizer,
            device=device,
        )

    @staticmethod
    def load_weights(
        model: FlashSantacoderForCausalLM,
        filenames: List[Path],
    ):
        for filename in filenames:
            state_dict = torch.load(filename, map_location="cpu")
            for key, value in state_dict.items():
                layer_name = ".".join(key.split(".")[:4])

                # Fused qkv
                if "q_attn.weight" in key or "kv_attn.weight" in key:
                    final_key = layer_name + ".attn.weight"
                elif "q_attn.bias" in key or "kv_attn.bias" in key:
                    final_key = layer_name + ".attn.bias"
                else:
                    final_key = key

                module_name, param_name = final_key.rsplit(".", 1)
                module = model.get_submodule(module_name)

                try:
                    current_parameter_tensor = module._parameters[param_name]
                except KeyError:
                    current_parameter_tensor = None

                if current_parameter_tensor is not None:
                    if (
                        "c_fc.weight" in key
                        or "c_proj.weight" in key
                        or "q_attn.weight" in key
                        or "kv_attn.weight" in key
                    ):
                        # Transpose as we use nn.Linear instead of Conv1D
                        value = value.T

                    if current_parameter_tensor.device == torch.device("meta"):
                        # Init qkv
                        if "attn.weight" in final_key:
                            module._parameters[param_name] = value.new_empty(
                                (
                                    model.transformer.head_size
                                    * (model.transformer.num_heads + 2),
                                    value.shape[1],
                                )
                            )
                        elif "attn.bias" in final_key:
                            module._parameters[param_name] = value.new_empty(
                                (
                                    model.transformer.head_size
                                    * (model.transformer.num_heads + 2)
                                )
                            )

                    # Copy to correct slice
                    if "q_attn.weight" in key:
                        module._parameters[param_name][: value.shape[0]] = value
                    elif "q_attn.bias" in key:
                        module._parameters[param_name][: value.shape[0]] = value
                    elif "kv_attn.weight" in key:
                        module._parameters[param_name][
                            model.transformer.head_size * model.transformer.num_heads :
                        ] = value
                    elif "kv_attn.bias" in key:
                        module._parameters[param_name][
                            model.transformer.head_size * model.transformer.num_heads :
                        ] = value
                    else:
                        if current_parameter_tensor.shape != value.shape:
                            raise ValueError(
                                f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
                            )
                        module._parameters[param_name] = value
                else:
                    module._buffers[param_name] = value

        torch.cuda.empty_cache()
        model.post_load_weights()

    def decode(self, generated_ids: List[int]) -> str:
        # Do not skip special tokens as they are used for custom parsing rules of the generated text
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=False, cleanup_tokenization_spaces=False
        )
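A minimal sketch (not from the commit) of the fused-qkv layout that load_weights builds: the checkpoint's separate q_attn and kv_attn tensors are copied into one attn weight whose first num_heads * head_size rows hold the query projection and whose last 2 * head_size rows hold the shared key/value projection. The sizes below are made-up illustration values.

# --- illustrative example, not part of the commit ---
import torch

num_heads, head_size = 4, 8
hidden_size = num_heads * head_size
q_weight = torch.randn(hidden_size, hidden_size)        # stands in for the query weight
kv_weight = torch.randn(2 * head_size, hidden_size)     # stands in for the shared K/V weight

fused = torch.empty(hidden_size + 2 * head_size, hidden_size)
fused[: q_weight.shape[0]] = q_weight                   # query rows
fused[num_heads * head_size :] = kv_weight              # key/value rows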
supported_models.json  (view file @ c0aeb325)

 [
+    "bigcode/santacoder",
     "bigscience/bloom",
     "bigscience/bloomz",
     "EleutherAI/gpt-neox-20b",
     "google/flan-ul2",
     "google/flan-t5-xxl",
-    "OpenAssistant/oasst-sft-1-pythia-12b",
-    "olivierdehaene/optimized-santacoder"
+    "OpenAssistant/oasst-sft-1-pythia-12b"
 ]