OpenDAS / text-generation-inference, commit 3b56d766 (unverified)

feat: add mistral model (#1071)

Authored Sep 28, 2023 by OlivierDehaene; committed by GitHub on Sep 28, 2023
Parent: 259a2300

Showing 7 changed files with 954 additions and 137 deletions (+954 -137)
server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py   +532  -0
server/text_generation_server/models/flash_causal_lm.py                          +40   -133
server/text_generation_server/models/flash_mistral.py                            +357  -0
server/text_generation_server/models/model.py                                    +6    -0
server/text_generation_server/utils/flash_attn.py                                +8    -0
server/text_generation_server/utils/layers.py                                    +1    -0
update_doc.py                                                                    +10   -4
server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py (new file, 0 → 100644, view file @ 3b56d766)
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed

from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple

# Flash attention imports
import dropout_layer_norm

# vllm imports
import vllm_cache_ops
import vllm_attention_ops

from text_generation_server.utils.flash_attn import attention, HAS_FLASH_ATTN_V2
from text_generation_server.utils.layers import (
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    PositionRotaryEmbedding,
    TensorParallelHead,
    get_linear,
)

if not HAS_FLASH_ATTN_V2:
    raise ImportError("Mistral model requires flash attn v2")


class MistralConfig(PretrainedConfig):
    model_type = "mistral"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        intermediate_size=14336,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=8,
        hidden_act="silu",
        max_position_embeddings=4096 * 32,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        sliding_window=4096,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.sliding_window = sliding_window

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


class MistralRMSNorm(nn.Module):
    def __init__(self, prefix, weights, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()

        weight = weights.get_tensor(f"{prefix}.weight")
        self.weight = nn.Parameter(weight)
        self.variance_epsilon = eps

    def forward(self, hidden_states, residual=None):
        if hidden_states.shape[-1] > 8192:
            if residual is not None:
                hidden_states += residual
            residual = hidden_states

            hidden_states = hidden_states.to(torch.float32)
            variance = hidden_states.pow(2).mean(-1, keepdim=True)
            hidden_states = hidden_states * torch.rsqrt(
                variance + self.variance_epsilon
            )

            # convert into half-precision if necessary
            if self.weight.dtype in [torch.float16, torch.bfloat16]:
                hidden_states = hidden_states.to(self.weight.dtype)

            return self.weight * hidden_states, residual
        else:
            # faster post attention rms norm
            normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
                hidden_states,
                residual,
                self.weight,
                None,
                None,
                None,
                None,
                None,
                0.0,
                self.variance_epsilon,
                1.0,
                0,
                None,
                False,
                True,  # Activate RMSNorm
            )
            if res is None:
                res = hidden_states

            return normed_hidden_states, res


def load_attention(config, prefix, weights):
    if config.num_attention_heads != config.num_key_value_heads:
        return _load_gqa(config, prefix, weights)
    else:
        return TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            weights=weights,
            bias=False,
        )


def _load_gqa(config, prefix: str, weights):
    assert config.hidden_size % config.num_attention_heads == 0
    assert config.num_attention_heads % weights.process_group.size() == 0

    weight = weights.get_multi_weights_col(
        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
        quantize=config.quantize,
        dim=0,
    )

    if config.quantize not in ["gptq", "awq"]:
        weight = weight.to(dtype=weights.dtype).to(device=weights.device)

        head_size = config.hidden_size // config.num_attention_heads
        num_heads = config.num_attention_heads // weights.process_group.size()
        num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
        assert list(weight.shape) == [
            (num_heads + 2 * num_key_value_heads) * head_size,
            config.hidden_size,
        ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"

    return TensorParallelColumnLinear(
        get_linear(weight, bias=None, quantize=config.quantize)
    )


class MistralAttention(torch.nn.Module):
    def __init__(
        self,
        prefix: str,
        config,
        weights,
    ):
        super().__init__()
        self.max_past = (
            config.sliding_window if config.sliding_window is not None else 0
        )
        self.num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.num_heads

        self.rotary_emb = PositionRotaryEmbedding.static(
            config=config,
            dim=self.head_size,
            base=config.rope_theta,
            device=weights.device,
        )

        self.softmax_scale = self.head_size**-0.5

        if self.num_heads % weights.process_group.size() != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {weights.process_group.size()}"
            )
        self.num_heads = self.num_heads // weights.process_group.size()
        self.num_key_value_heads = (
            config.num_key_value_heads // weights.process_group.size()
        )

        self.query_key_value = load_attention(config, prefix, weights)

        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )
        self.num_groups = self.num_heads // self.num_key_value_heads
        self.kv_head_mapping = torch.arange(
            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
        ).repeat_interleave(self.num_groups)

    def forward(
        self,
        hidden_states,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        block_tables,
        slots,
        input_lengths,
        max_s,
        prefill_cache_indices,
    ):
        qkv = self.query_key_value(hidden_states)
        query, kv = qkv.split(
            [
                self.head_size * self.num_heads,
                2 * self.head_size * self.num_key_value_heads,
            ],
            dim=1,
        )
        query = query.view(-1, self.num_heads, self.head_size)
        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

        self.rotary_emb(query, cos, sin)
        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)

        if prefill_cache_indices is not None:
            kv_to_cache = kv[prefill_cache_indices]
        else:
            kv_to_cache = kv

        vllm_cache_ops.reshape_and_cache(
            kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
        )

        # output tensor
        attn_output = torch.empty_like(query)

        # Prefill
        if cu_seqlen_prefill is not None:
            # flash attention
            attention(
                query,
                torch.select(kv, dim=1, index=0),
                torch.select(kv, dim=1, index=1),
                attn_output,
                cu_seqlen_prefill,
                max_s,
                self.softmax_scale,
                window_size_left=self.max_past,
            )
        # Decode
        else:
            # kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
            block_size = kv_cache[1].shape[3]
            vllm_attention_ops.single_query_cached_kv_attention(
                attn_output,
                query,
                kv_cache[0],
                kv_cache[1],
                self.kv_head_mapping,
                self.softmax_scale,
                block_tables,
                input_lengths,
                block_size,
                max_s,
            )

        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))


class MistralMLP(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        act = config.hidden_act
        self.act = (
            ACT2FN[act]
            if "gelu" not in act
            else lambda x: torch.nn.functional.gelu(
                x,
                approximate="tanh"
                if act in ["gelu_fast", "gelu_pytorch_tanh"]
                else "none",
            )
        )
        # Fuse gate and up proj
        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
            weights=weights,
            dim=0,
            bias=False,
        )
        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=False,
        )
        self.intermediate_size = (
            config.intermediate_size // weights.process_group.size()
        )

    def forward(self, hidden_states):
        gate_up_states = self.gate_up_proj(hidden_states)
        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
        return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])


class MistralLayer(nn.Module):
    def __init__(self, layer_id, config, weights):
        super().__init__()
        prefix = f"model.layers.{layer_id}"
        self.self_attn = MistralAttention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights
        )
        self.mlp = MistralMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)

        self.input_layernorm = MistralRMSNorm(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = MistralRMSNorm(
            prefix=f"{prefix}.post_attention_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )

    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        block_tables,
        slots,
        input_lengths,
        max_s,
        prefill_cache_indices,
    ):
        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)

        # Self Attention
        attn_output = self.self_attn(
            normed_hidden_states,
            cos,
            sin,
            cu_seqlen_prefill,
            kv_cache,
            block_tables,
            slots,
            input_lengths,
            max_s,
            prefill_cache_indices,
        )

        # faster post attention rms norm
        normed_attn_res_output, attn_res = self.post_attention_layernorm(
            attn_output, res
        )

        mlp_output = self.mlp(normed_attn_res_output)

        return mlp_output, attn_res


class MistralModel(torch.nn.Module):
    def __init__(self, config, weights):
        super().__init__()

        process_group = weights.process_group
        self.tp_rank = process_group.rank()
        self.tp_world_size = process_group.size()
        self.embed_tokens = TensorParallelEmbedding(
            prefix="model.embed_tokens", weights=weights
        )
        self.layers = nn.ModuleList(
            [
                MistralLayer(
                    layer_id,
                    config,
                    weights,
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.norm = MistralRMSNorm(
            prefix="model.norm", weights=weights, eps=config.rms_norm_eps
        )

        self.gradient_checkpointing = False

        self.head_size = self.layers[0].self_attn.head_size
        self.num_heads = self.layers[0].self_attn.num_heads
        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        input_lengths: torch.Tensor,
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)

        # Get rotary cos and sin for this forward
        # Avoid indexing in each layer
        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
            position_ids, max_s, hidden_states.dtype
        )

        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                residual,
                cos,
                sin,
                cu_seqlen_prefill,
                kv_cache[i],
                block_tables,
                slots,
                input_lengths,
                max_s,
                prefill_cache_indices,
            )

        hidden_states, _ = self.norm(hidden_states, residual)

        return hidden_states


class FlashMistralForCausalLM(torch.nn.Module):
    def __init__(self, config, weights):
        super().__init__()

        self.model = MistralModel(config, weights)
        self.lm_head = TensorParallelHead.load(
            config,
            prefix="lm_head",
            weights=weights,
        )
        self.max_past = config.sliding_window
        if self.max_past is None:
            raise ValueError("max_past cannot be None")

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        input_lengths: torch.Tensor,
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if prefill_cache_indices is not None:
            # Slots also need to be sliced as they have the same size as the whole kv tensor
            slots = slots[prefill_cache_indices]
        else:
            # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
            # kernel requires the true values
            max_s = min(self.max_past, max_s)
            input_lengths = torch.clamp(input_lengths, max=self.max_past)

        hidden_states = self.model(
            input_ids,
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            block_tables,
            slots,
            input_lengths,
            max_s,
            prefill_cache_indices,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits = self.lm_head(hidden_states)
        return logits
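A note on the grouped-query attention wiring above: kv_head_mapping tells the paged-attention decode kernel which KV head each query head should read. A minimal standalone sketch of that mapping, assuming 32 query heads and 8 KV heads on a single shard (process_group.size() == 1):

import torch

num_heads = 32           # query heads on this shard (hypothetical single-shard setup)
num_key_value_heads = 8  # KV heads on this shard
num_groups = num_heads // num_key_value_heads  # 4 query heads share each KV head

# Same construction as in MistralAttention.__init__ above
kv_head_mapping = torch.arange(
    0, num_key_value_heads, dtype=torch.int32
).repeat_interleave(num_groups)

print(kv_head_mapping)
# tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3,
#         4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7], dtype=torch.int32)

Query head i reads the cached keys and values of KV head kv_head_mapping[i], which is what single_query_cached_kv_attention consumes at decode time.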
server/text_generation_server/models/flash_causal_lm.py (view file @ 3b56d766)
@@ -19,99 +19,17 @@ from text_generation_server.models.types import (
     GeneratedText,
     TopTokens,
 )
+from text_generation_server.models.cache_manager import (
+    get_cache_manager,
+    set_cache_manager,
+    BLOCK_SIZE,
+)
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 from text_generation_server.utils.dist import MEMORY_FRACTION
 
 tracer = trace.get_tracer(__name__)
 
-BLOCK_SIZE = 16
-# Will be set in warmup
-CACHE_MANAGER: Optional["CacheManager"] = None
-
-
-class CacheManager:
-    def __init__(
-        self,
-        num_blocks: int,
-        num_layers: int,
-        num_heads: int,
-        head_size: int,
-        dtype: torch.dtype,
-        device: torch.device,
-    ):
-        self.block_size = BLOCK_SIZE
-        self.num_blocks = num_blocks
-
-        element_size = torch.tensor([], dtype=dtype).element_size()
-        x = self.block_size // element_size
-        self.kv_cache = [
-            (
-                torch.empty(
-                    (num_blocks, num_heads, head_size // x, self.block_size, x),
-                    dtype=dtype,
-                    device=device,
-                ),
-                torch.empty(
-                    (num_blocks, num_heads, head_size, self.block_size),
-                    dtype=dtype,
-                    device=device,
-                ),
-            )
-            for _ in range(num_layers)
-        ]
-        self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
-        self.slots = torch.arange(
-            0, num_blocks * self.block_size, dtype=torch.int32
-        ).view(num_blocks, self.block_size)
-
-    def allocate(self, batch: "FlashCausalLMBatch"):
-        # Get free blocks indices by finding values in mask that are not set to 0
-        free_block_indices = self.free_block_mask.nonzero()
-        assert (
-            len(free_block_indices) >= batch.blocks
-        ), f"Out of available cache blocks: asked {batch.blocks}, only {len(free_block_indices)} free blocks"
-
-        # Slice by the number of required blocks
-        block_indices = free_block_indices[: batch.blocks]
-        block_indices = block_indices.flatten()
-
-        # Padded block tables
-        block_tables_tensor = torch.zeros(
-            (len(batch), batch.max_blocks), dtype=torch.int32
-        )
-
-        # Allocate paged attention blocks
-        cumulative_blocks = 0
-        slots = []
-        block_tables = []
-        for i, (needed_blocks, needed_slots) in enumerate(batch.needed_blocks_slots):
-            # Get allocated blocks for this sequence
-            allocated_blocks = block_indices[
-                cumulative_blocks : cumulative_blocks + needed_blocks
-            ]
-            # Get slots for the allocated blocks
-            allocated_slots = self.slots[allocated_blocks].flatten()[:needed_slots]
-
-            slots.append(allocated_slots)
-            block_tables.append(allocated_blocks.tolist())
-            block_tables_tensor[i, :needed_blocks] = allocated_blocks
-            cumulative_blocks += needed_blocks
-
-        batch.needed_blocks_slots = None
-        batch.block_tables = block_tables
-        batch.block_tables_tensor = block_tables_tensor.to(batch.input_ids.device)
-        batch.slots = torch.concat(slots).to(batch.input_ids.device)
-
-        # Allocate the required number of blocks by setting the mask to 0
-        self.free_block_mask[block_indices] = 0
-
-    def free(self, block_indices: Optional[List[int]]):
-        if block_indices is not None and block_indices:
-            # Reset mask
-            self.free_block_mask[block_indices] = 1
-
 
 @dataclass
 class FlashCausalLMBatch(Batch):
@@ -481,7 +399,6 @@ class FlashCausalLMBatch(Batch):
             max_blocks = max(max_blocks, len(request_block_table))
 
-        global CACHE_MANAGER
         block_indices_to_free = []
         # Iterate on all requests
         for i, r in enumerate(self.requests):
@@ -489,7 +406,7 @@ class FlashCausalLMBatch(Batch):
             if r.id not in requests_idx_mapping.keys():
                 block_indices_to_free.extend(self.block_tables[i])
         # Free blocks
-        CACHE_MANAGER.free(block_indices_to_free)
+        get_cache_manager().free(block_indices_to_free)
         # Needed to avoid dropping blocks when the batches will go out of scope
         self.block_tables = None
@@ -508,7 +425,7 @@ class FlashCausalLMBatch(Batch):
         # Move to GPU now that we have the whole tensor
         slot_indices = slot_indices.to(device)
 
-        return FlashCausalLMBatch(
+        return type(self)(
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
@@ -665,7 +582,7 @@ class FlashCausalLMBatch(Batch):
            b.block_tables = None
            del b
 
-        return FlashCausalLMBatch(
+        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
@@ -698,9 +615,10 @@ class FlashCausalLMBatch(Batch):
     def __del__(self):
         if self.block_tables is not None and self.block_tables:
-            global CACHE_MANAGER
             # Free blocks
-            CACHE_MANAGER.free(list(itertools.chain.from_iterable(self.block_tables)))
+            get_cache_manager().free(
+                list(itertools.chain.from_iterable(self.block_tables))
+            )
 
     def __len__(self):
         return len(self.requests)
@@ -718,6 +636,7 @@ class FlashCausalLM(Model):
         device: torch.device,
         rank: int = 0,
         world_size: int = 1,
+        sliding_window: Optional[int] = None,
     ):
         self.num_layers = num_layers
         self.num_kv_heads = num_kv_heads
@@ -731,6 +650,7 @@ class FlashCausalLM(Model):
             device=device,
             rank=rank,
             world_size=world_size,
+            sliding_window=sliding_window,
         )
 
     @property
@@ -738,15 +658,14 @@ class FlashCausalLM(Model):
         return FlashCausalLMBatch
 
     def warmup(self, batch: FlashCausalLMBatch):
-        global CACHE_MANAGER
-
         torch.cuda.empty_cache()
         try:
-            CACHE_MANAGER = CacheManager(
+            cache_manager = set_cache_manager(
                 batch.blocks,
                 self.num_layers,
                 self.num_kv_heads,
                 self.head_size,
+                self.sliding_window is not None,
                 self.dtype,
                 self.device,
             )
@@ -775,48 +694,36 @@ class FlashCausalLM(Model):
            num_blocks = (
                int(free_memory // total_cache_size)
                # Add batch.blocks as we allocated it above, so it is included in the peak memory.
-                + CACHE_MANAGER.num_blocks
+                + cache_manager.num_blocks
            )
 
-            del CACHE_MANAGER
            del batch
-            torch.cuda.empty_cache()
+            del cache_manager
 
-            CACHE_MANAGER = CacheManager(
+            set_cache_manager(
                num_blocks,
                self.num_layers,
                self.num_kv_heads,
                self.head_size,
+                self.sliding_window is not None,
                self.dtype,
                self.device,
            )
 
        return int(num_blocks * BLOCK_SIZE)
 
     def forward(
-        self, batch: FlashCausalLMBatch
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        global CACHE_MANAGER
-
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        cu_seqlen_prefill: Optional[torch.Tensor],
+        block_tables: torch.Tensor,
+        slots: torch.Tensor,
+        input_lengths: torch.Tensor,
+        max_s: int,
+        lm_head_indices: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Model Forward
         return self.model.forward(
-            input_ids=batch.input_ids,
-            position_ids=batch.position_ids,
-            cu_seqlen_prefill=batch.cu_seqlen_prefill,
-            kv_cache=CACHE_MANAGER.kv_cache,
-            block_tables=batch.block_tables_tensor,
-            slots=batch.slots[batch.slot_indices],
-            input_lengths=batch.input_lengths_tensor,
-            max_s=batch.max_seqlen,
-            lm_head_indices=batch.prefill_head_indices,
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cu_seqlen_prefill=cu_seqlen_prefill,
+            kv_cache=get_cache_manager().kv_cache,
+            block_tables=block_tables,
+            slots=slots,
+            input_lengths=input_lengths,
+            max_s=max_s,
+            lm_head_indices=lm_head_indices,
        )
@@ -828,19 +735,19 @@ class FlashCausalLM(Model):
         if batch.needed_blocks_slots:
             # Allocate blocks to this batch
-            CACHE_MANAGER.allocate(batch)
+            block_tables, block_tables_tensor, slots = get_cache_manager().allocate(
+                batch.needed_blocks_slots,
+                batch.blocks,
+                batch.max_blocks,
+                batch.input_ids.device,
+            )
+            batch.needed_blocks_slots = None
+            batch.block_tables = block_tables
+            batch.block_tables_tensor = block_tables_tensor
+            batch.slots = slots
 
         try:
-            out = self.forward(batch)
+            out = self.forward(
+                batch.input_ids,
+                batch.position_ids,
+                batch.cu_seqlen_prefill,
+                batch.block_tables_tensor,
+                batch.slots[batch.slot_indices],
+                batch.input_lengths_tensor,
+                batch.max_seqlen,
+                batch.prefill_head_indices,
+            )
         except Exception as e:
             del batch
             raise e
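The block arithmetic behind the CacheManager removed here (and its replacement reached through get_cache_manager() / set_cache_manager()) is easy to reproduce. A minimal sketch, assuming the default BLOCK_SIZE = 16 and a hypothetical request with a 35-token prompt and a budget of 20 new tokens:

import math

BLOCK_SIZE = 16       # paged-attention block size, as above
input_length = 35     # hypothetical prompt length
max_new_tokens = 20   # hypothetical decoding budget

# Minus one because, as the batch-building code puts it, the first token does not have a past
total_tokens = input_length + max_new_tokens - 1       # 54 KV slots
needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)   # 4 blocks of 16 slots each

print(total_tokens, needed_blocks)  # 54 4

Each allocated block contributes BLOCK_SIZE contiguous slots, and the per-request block table lists which physical blocks back those slots.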
server/text_generation_server/models/flash_mistral.py (new file, 0 → 100644, view file @ 3b56d766)
import math

import torch
import torch.distributed

import numpy as np

from dataclasses import dataclass
from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from transformers.models.llama import LlamaTokenizerFast
from typing import Optional, Tuple, Type

from text_generation_server.pb import generate_pb2
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch, BLOCK_SIZE
from text_generation_server.models.cache_manager import (
    get_cache_manager,
    set_cache_manager,
)
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
    FlashMistralForCausalLM,
    MistralConfig,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
    HeterogeneousNextTokenChooser,
    StoppingCriteria,
)

tracer = trace.get_tracer(__name__)

# Will be set in init
SLIDING_WINDOW: Optional[int] = None
SLIDING_WINDOW_BLOCKS: Optional[int] = None


# Adds windowing logic to FlashCausalLMBatch
@dataclass
class FlashMistralBatch(FlashCausalLMBatch):
    # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers
    # as we only keep SLIDING_WINDOW values instead of the whole tensor
    prefill_cache_indices: Optional[torch.Tensor] = None

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "FlashCausalLMBatch":
        global SLIDING_WINDOW
        global SLIDING_WINDOW_BLOCKS

        batch_inputs = []
        max_truncation = 0
        for r in pb.requests:
            batch_inputs.append(r.inputs)
            max_truncation = max(max_truncation, r.truncate)

        batch_tokenized_inputs = tokenizer(
            batch_inputs, truncation=True, max_length=max_truncation
        )["input_ids"]

        position_ids = []
        cu_seqlen_prefill = [0]
        needed_blocks_slots = []
        start_slots = []
        slot_indices = []
        prefill_cache_indices = []

        input_lengths = []
        prefix_offsets = []
        read_offsets = []
        all_input_ids = []
        requests_idx_mapping = {}

        all_prefill_logprobs = True
        no_prefill_logprobs = True
        prefill_head_indices = []
        prefill_next_token_indices = []
        prefill_cu_outlens = [0]

        next_token_chooser_parameters = []
        stopping_criterias = []
        top_n_tokens = []

        # Cumulative length
        cumulative_length = 0
        cumulative_max_length = 0
        prefill_out_cumulative_length = 0

        blocks = 0
        max_seqlen = 0
        max_length = 0
        max_blocks = 0

        # Parse batch
        for i, (r, tokenized_input) in enumerate(
            zip(pb.requests, batch_tokenized_inputs)
        ):
            # request id -> idx in list mapping
            requests_idx_mapping[r.id] = i

            tokenized_input = tokenized_input[-r.truncate :]

            input_length = len(tokenized_input)
            input_lengths.append(input_length)

            prefix_offsets.append(input_length - 5)
            read_offsets.append(input_length)

            all_input_ids.append(tokenized_input)

            # Position ids
            request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
            position_ids.append(request_position_ids)

            # Add cumulative lengths of all previous inputs
            cu_seqlen_prefill.append(cumulative_length + input_length)

            next_token_chooser_parameters.append(r.parameters)

            stopping_criteria = StoppingCriteria.from_pb(
                r.stopping_parameters, tokenizer
            )
            max_new_tokens = stopping_criteria.max_new_tokens
            stopping_criterias.append(stopping_criteria)
            top_n_tokens.append(r.top_n_tokens)

            # Paged attention
            # Remove one as the first token does not have a past
            total_tokens = input_length + max_new_tokens - 1

            # Needed blocks can not go over SLIDING_WINDOW_BLOCKS
            needed_blocks = min(
                math.ceil(total_tokens / BLOCK_SIZE), SLIDING_WINDOW_BLOCKS
            )
            blocks += needed_blocks

            needed_blocks_slots.append((needed_blocks, total_tokens))
            start_slots.append(cumulative_max_length)

            request_slot_indices = torch.arange(
                cumulative_max_length,
                cumulative_max_length + input_length,
                dtype=torch.int64,
            )
            slot_indices.append(request_slot_indices)

            # Create tensor to slice into the kv tensor in prefill
            request_prefill_cache_indices = torch.arange(
                cumulative_length + max(0, input_length - SLIDING_WINDOW),
                cumulative_length + input_length,
                dtype=torch.int64,
            )
            prefill_cache_indices.append(request_prefill_cache_indices)

            all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
            no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs

            if r.prefill_logprobs:
                prefill_head_indices.append(request_position_ids + cumulative_length)
                prefill_next_token_indices.append(
                    prefill_out_cumulative_length + input_length - 1
                )
                prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
                prefill_out_cumulative_length += input_length
            else:
                prefill_head_indices.append(
                    torch.tensor(
                        [cumulative_length + input_length - 1], dtype=torch.int32
                    )
                )
                prefill_next_token_indices.append(prefill_out_cumulative_length)
                prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
                prefill_out_cumulative_length += 1

            # Update
            cumulative_length += input_length
            cumulative_max_length += total_tokens
            max_seqlen = max(max_seqlen, input_length)
            max_blocks = max(max_blocks, needed_blocks)
            max_length = max(max_length, input_length + max_new_tokens)

        next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
            next_token_chooser_parameters, dtype, device
        )
        start_slots = torch.tensor(start_slots, dtype=torch.int64)

        # Padded all_input_ids_tensor
        all_input_ids_tensor = np.zeros(
            (len(all_input_ids), max_length), dtype=np.int64
        )
        for i, input_ids in enumerate(all_input_ids):
            all_input_ids_tensor[i, : len(input_ids)] = input_ids

        # Create tensors on device
        all_input_ids_tensor = torch.tensor(
            all_input_ids_tensor, dtype=torch.int64, device=device
        )

        if len(pb.requests) > 1:
            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
            position_ids = torch.cat(position_ids)
            slot_indices = torch.cat(slot_indices)
            prefill_cache_indices = torch.cat(prefill_cache_indices)
        else:
            input_ids = all_input_ids[0]
            position_ids = position_ids[0]
            slot_indices = slot_indices[0]
            prefill_cache_indices = prefill_cache_indices[0]

        cu_seqlen_prefill = torch.tensor(
            cu_seqlen_prefill, device=device, dtype=torch.int32
        )

        position_ids = position_ids.to(device)
        slot_indices = slot_indices.to(device)
        prefill_cache_indices = prefill_cache_indices.to(device)
        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
        input_lengths_tensor = torch.tensor(
            input_lengths, dtype=torch.int32, device=device
        )

        if all_prefill_logprobs:
            prefill_head_indices = None
            prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
        elif no_prefill_logprobs:
            prefill_head_indices = cu_seqlen_prefill[1:] - 1
            prefill_next_token_indices = None
        else:
            prefill_head_indices = torch.tensor(
                torch.cat(prefill_head_indices), dtype=torch.int64, device=device
            )
            prefill_next_token_indices = torch.tensor(
                prefill_next_token_indices, dtype=torch.int64, device=device
            )
        top_n_tokens_tensor = torch.tensor(
            top_n_tokens, device=device, dtype=torch.int64
        )

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            start_slots=start_slots,
            slot_indices=slot_indices,
            needed_blocks_slots=needed_blocks_slots,
            block_tables=None,
            block_tables_tensor=None,
            slots=None,
            max_seqlen=max_seqlen,
            prefill_head_indices=prefill_head_indices,
            prefill_next_token_indices=prefill_next_token_indices,
            prefill_cu_outlens=prefill_cu_outlens,
            input_lengths=input_lengths,
            input_lengths_tensor=input_lengths_tensor,
            prefix_offsets=prefix_offsets,
            read_offsets=read_offsets,
            all_input_ids=all_input_ids,
            all_input_ids_tensor=all_input_ids_tensor,
            next_token_chooser=next_token_chooser,
            stopping_criterias=stopping_criterias,
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=top_n_tokens_tensor,
            blocks=blocks,
            max_blocks=max_blocks,
            prefill_cache_indices=prefill_cache_indices,
        )


class FlashMistral(FlashCausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        global SLIDING_WINDOW
        global SLIDING_WINDOW_BLOCKS

        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            raise NotImplementedError("FlashMistral is only available on GPU")

        tokenizer = LlamaTokenizerFast.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = MistralConfig.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize

        # Set context windows
        SLIDING_WINDOW = config.sliding_window
        SLIDING_WINDOW_BLOCKS = math.ceil(config.sliding_window / BLOCK_SIZE)

        torch.distributed.barrier(group=self.process_group)

        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)
        if config.quantize in ["gptq", "awq"]:
            weights._set_gptq_params(model_id)

        model = FlashMistralForCausalLM(config, weights)

        torch.distributed.barrier(group=self.process_group)
        super(FlashMistral, self).__init__(
            model=model,
            tokenizer=tokenizer,
            num_layers=len(model.model.layers),
            num_kv_heads=model.model.num_key_value_heads,
            head_size=model.model.head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
            sliding_window=config.sliding_window,
        )

    @property
    def batch_type(self) -> Type[FlashMistralBatch]:
        return FlashMistralBatch

    def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]:
        # Model Forward
        logits = self.model.forward(
            input_ids=batch.input_ids,
            position_ids=batch.position_ids,
            cu_seqlen_prefill=batch.cu_seqlen_prefill,
            kv_cache=get_cache_manager().kv_cache,
            block_tables=batch.block_tables_tensor,
            slots=batch.slots[batch.slot_indices],
            input_lengths=batch.input_lengths_tensor,
            max_s=batch.max_seqlen,
            prefill_cache_indices=batch.prefill_cache_indices,
            lm_head_indices=batch.prefill_head_indices,
        )
        if batch.prefill_cache_indices is not None:
            batch.prefill_cache_indices = None
        return logits
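To make the windowing in FlashMistralBatch.from_pb above concrete: for each request, prefill_cache_indices keeps only the last SLIDING_WINDOW positions of the prompt, and only those positions are ever written into the paged KV cache. A minimal sketch with hypothetical sizes (a 10-token prompt, a window of 6, and a batch offset cumulative_length = 100):

import torch

SLIDING_WINDOW = 6       # hypothetical window; the real value comes from config.sliding_window
input_length = 10        # hypothetical prompt length
cumulative_length = 100  # offset of this request inside the flattened batch

# Same construction as in FlashMistralBatch.from_pb above
request_prefill_cache_indices = torch.arange(
    cumulative_length + max(0, input_length - SLIDING_WINDOW),
    cumulative_length + input_length,
    dtype=torch.int64,
)

print(request_prefill_cache_indices)
# tensor([104, 105, 106, 107, 108, 109])

Only those six positions get cached; the flash-attention prefill kernel still sees the full prompt, which is why FlashMistralForCausalLM.forward slices slots with the same indices before the KV pairs are written by reshape_and_cache.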
server/text_generation_server/models/model.py (view file @ 3b56d766)
@@ -21,6 +21,7 @@ class Model(ABC):
         device: torch.device,
         rank: int = 0,
         world_size: int = 1,
+        sliding_window: Optional[int] = None,
     ):
         self.model = model.eval()
         self.tokenizer = tokenizer
@@ -30,6 +31,7 @@ class Model(ABC):
         self.device = device
         self.rank = rank
         self.world_size = world_size
+        self.sliding_window = sliding_window
         self.has_position_ids = (
             inspect.signature(model.forward).parameters.get("position_ids", None)
@@ -40,10 +42,14 @@ class Model(ABC):
     @property
     def info(self) -> InfoResponse:
+        if self.requires_padding and self.sliding_window is not None:
+            raise NotImplementedError("sliding_window is not implemented with padding")
+
         return InfoResponse(
             requires_padding=self.requires_padding,
             dtype=str(self.dtype),
             device_type=self.device.type,
+            window_size=self.sliding_window,
         )
 
     @property
server/text_generation_server/utils/flash_attn.py (view file @ 3b56d766)
@@ -57,6 +57,7 @@ def attention(
     cu_seqlens,
     max_s,
     softmax_scale,
+    window_size_left=-1,
 ):
     if HAS_FLASH_ATTN_V2:
         return flash_attn_2_cuda.varlen_fwd(
@@ -72,11 +73,18 @@ def attention(
             softmax_scale,
             False,
             True,
+            window_size_left,
+            0,
             False,
             None,
         )
 
     if HAS_FLASH_ATTN:
+        if window_size_left != 0:
+            raise NotImplementedError(
+                "window_size_left is only available with flash attn v2"
+            )
+
         # Flash attention v1 requires q, k and v to have the same number of heads
         if k.shape[1] != q.shape[1]:
             # MQA expand
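For intuition on the window_size_left argument added to attention above: with flash-attn v2 it limits every query to the current position plus the window_size_left keys immediately before it. A rough pure-PyTorch illustration of that mask, for intuition only (the fused kernel never materialises a mask, and this is not the code path used above):

import torch

seq_len, window_size_left = 6, 3  # hypothetical sizes

q_pos = torch.arange(seq_len).unsqueeze(1)  # query positions, one row per query
k_pos = torch.arange(seq_len).unsqueeze(0)  # key positions, one column per key

# Causal mask restricted to the current token and the last `window_size_left` keys
mask = (k_pos <= q_pos) & (k_pos >= q_pos - window_size_left)
print(mask.int())

Scores outside the mask would be set to -inf before the softmax; the fused kernel achieves the same effect internally, which is what lets Mistral cap its KV cache at the sliding window.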
server/text_generation_server/utils/layers.py (view file @ 3b56d766)
@@ -53,6 +53,7 @@ try:
 except ImportError:
     pass
 
 
 # Monkey patching
 @classmethod
 def load_layer_norm(cls, prefix, weights, eps):
update_doc.py (view file @ 3b56d766)
@@ -8,7 +8,9 @@ def main():
     args = parser.parse_args()
 
-    output = subprocess.check_output(["text-generation-launcher", "--help"]).decode("utf-8")
+    output = subprocess.check_output(["text-generation-launcher", "--help"]).decode(
+        "utf-8"
+    )
 
     final_doc = f"# Text-generation-launcher arguments\n```\n{output}\n```"
 
     filename = "docs/source/basic_tutorials/launcher.md"
@@ -16,16 +18,20 @@ def main():
     with open(filename, "r") as f:
         doc = f.read()
     if doc != final_doc:
         tmp = "launcher.md"
         with open(tmp, "w") as g:
             g.write(final_doc)
-        diff = subprocess.run(["diff", tmp, filename], capture_output=True).stdout.decode("utf-8")
+        diff = subprocess.run(
+            ["diff", tmp, filename], capture_output=True
+        ).stdout.decode("utf-8")
         print(diff)
         raise Exception(
             "Doc is not up-to-date, run `python update_doc.py` in order to update it"
         )
     else:
         with open(filename, "w") as f:
             f.write(final_doc)
 
 
 if __name__ == "__main__":
     main()