OpenDAS / text-generation-inference / Commits

Commit 5a1cf2f0
Authored May 22, 2024 by huangwb

Merge tag 'v2.0.2' into dev-rocm

Parents: 24f58bb6, 6073ece4
Changes: 70 files in total. This page lists 10 changed files with 338 additions and 126 deletions (+338, -126); the remaining files are on later pages not captured here.
server/text_generation_server/models/vlm_causal_lm.py    +80  -36
server/text_generation_server/server.py                  +18  -5
server/text_generation_server/utils/dist.py              +8   -1
server/text_generation_server/utils/flash_attn.py        +87  -53
server/text_generation_server/utils/import_utils.py      +11  -0
server/text_generation_server/utils/layers.py            +53  -5
server/text_generation_server/utils/logits_process.py    +15  -6
server/text_generation_server/utils/paged_attention.py   +29  -5
server/text_generation_server/utils/tokens.py            +24  -9
server/text_generation_server/utils/weights.py           +13  -6
server/text_generation_server/models/vlm_causal_lm.py (view file @ 5a1cf2f0)
@@ -64,6 +64,46 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    return height // patch_size, width // patch_size


def image_text_replacement(image_input, config, image_id) -> str:
    if config.model_type == "idefics2":
        # TODO technically depends on image splitting which is not implemented.
        num_features = 320
        return (
            "<fake_token_around_image>"
            + "<image>" * num_features
            + "<fake_token_around_image>"
        )
    elif config.model_type == "llava_next":
        height, width = image_input["image_sizes"][image_id]
        num_features = get_number_of_features(height, width, config)
        from loguru import logger

        logger.info(f"Found {num_features} in image of resolution {height}x{width}")
        return "<image>" * num_features
    else:
        raise RuntimeError(f"Unknown config {config.model_type} for multimodal")


def get_unpadded_features(
    height: int,
    width: int,
    npatches: int,
    num_patch_height: int,
    num_patch_width: int,
) -> Tuple[int, int]:
    current_height = npatches * num_patch_height
    current_width = npatches * num_patch_width

    aspect_ratio: float = width / height
    current_aspect_ratio: float = current_width / current_height
    if aspect_ratio > current_aspect_ratio:
        new_height = (height * current_width) // width
        current_height = new_height
    else:
        new_width = (width * current_height) // height
        current_width = new_width

    unpadded_features = current_height * current_width
    newline_features = current_height
    return (unpadded_features, newline_features)


def get_number_of_features(height: int, width: int, config) -> int:
    # From config
    # Hardcoded for CLIP for now
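For reference, a quick worked example of the aspect-ratio arithmetic in get_unpadded_features above. The concrete numbers (a 640x480 image tiled into a 2x2 grid with 24 patches per tile side) are illustrative assumptions, not values taken from this diff.

# Illustrative sketch only: assumed values, mirroring get_unpadded_features.
height, width = 480, 640
npatches, num_patch_height, num_patch_width = 24, 2, 2

current_height = npatches * num_patch_height            # 48
current_width = npatches * num_patch_width              # 48
aspect_ratio = width / height                           # ~1.33
current_aspect_ratio = current_width / current_height   # 1.0
# The image is wider than the padded grid, so the effective grid height shrinks.
current_height = (height * current_width) // width      # 36
unpadded_features = current_height * current_width      # 1728
newline_features = current_height                       # 36, one newline feature per row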
@@ -81,12 +121,9 @@ def get_number_of_features(height: int, width: int, config) -> int:
        image_grid_pinpoints,
        image_size,
    )
-    height_of_patch = math.ceil(height / width * npatches)
-    unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width
-    # They are only added after width
-    newline_features = height_of_patch * num_patch_width
+    unpadded_features, newline_features = get_unpadded_features(
+        height, width, npatches, num_patch_height, num_patch_width
+    )
    # The base patch covers the entire image
    base_features = npatches**2
    return unpadded_features + newline_features + base_features
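Continuing the illustrative numbers from the sketch above (1728 unpadded features and 36 newline features for npatches=24), the rewritten get_number_of_features would then return:

# Continuation of the sketch above; all numbers are assumed for illustration.
unpadded_features, newline_features = 1728, 36
base_features = 24**2                        # 576: the base patch covers the entire image
print(unpadded_features + newline_features + base_features)   # 2340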
@@ -99,12 +136,9 @@ def load_data_uri(image_uri: str) -> Image.Image:
    return image


# assert get_number_of_features(889, 1024) == 2634, f"{get_number_of_features(889, 1024)}"
# assert get_number_of_features(640, 640) == 2928


class VlmCausalLMBatch(FlashMistralBatch):
    pixel_values: Optional[List[torch.Tensor]]
    pixel_attention_mask: Optional[List[torch.Tensor]]
    image_sizes: Optional[List[Tuple[int, int]]]

    @classmethod
@@ -112,6 +146,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
    def concatenate(cls, batches):
        batch = super(VlmCausalLMBatch, cls).concatenate(batches)
        batch.pixel_values = None
        batch.pixel_attention_mask = None
        batch.image_sizes = None
        return batch
@@ -119,6 +154,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
    def filter(self, request_ids: List[int]):
        batch = super().filter(request_ids)
        batch.pixel_values = None
        batch.pixel_attention_mask = None
        batch.image_sizes = None
        return batch
@@ -130,6 +166,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
        for r in requests:
            chunks = split(r.inputs)
            full_text = ""
            image_id = 0
            for chunk in chunks:
                if chunk["type"] == "text":
                    full_text += chunk["content"]
@@ -147,9 +184,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
                            "Cannot process input image not starting with data:"
                        )
                    image_input = processor.image_processor(image, return_tensors="pt")
-                    height, width = image_input["image_sizes"][0]
-                    num_features = get_number_of_features(height, width, config)
-                    full_text += "<image>" * num_features
+                    full_text += image_text_replacement(image_input, config, image_id)
                    image_inputs.append(image_input)
                else:
                    raise RuntimeError(f"Invalid chunk type {chunk['type']}")
@@ -161,12 +196,21 @@ class VlmCausalLMBatch(FlashMistralBatch):
            batch_inputs, truncation=True, max_length=max_truncation
        )["input_ids"]
        if image_inputs:
-            image_inputs = {
+            image_input = image_inputs[0]
+            new_image_inputs = {
                "pixel_values": torch.cat(
                    [img["pixel_values"] for img in image_inputs], dim=0
                ),
-                "image_sizes": torch.cat([img["image_sizes"] for img in image_inputs]),
            }
+            if "pixel_attention_mask" in image_input:
+                new_image_inputs["pixel_attention_mask"] = torch.cat(
+                    [img["pixel_attention_mask"] for img in image_inputs], dim=0
+                )
+            if "image_sizes" in image_input:
+                new_image_inputs["image_sizes"] = torch.cat(
+                    [img["image_sizes"] for img in image_inputs], dim=0
+                )
+            image_inputs = new_image_inputs
        else:
            image_inputs = None
        return batch_tokenized_inputs, image_inputs
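For illustration, a self-contained sketch of the merge performed above: per-image processor outputs are concatenated key by key, and the optional keys (pixel_attention_mask, image_sizes) are merged only when the first output contains them. The tensor shapes below are invented, not taken from the diff.

# Illustrative only: merging per-image processor outputs by key, as in the hunk above.
import torch

image_inputs = [
    {"pixel_values": torch.rand(1, 3, 336, 336), "image_sizes": torch.tensor([[480, 640]])},
    {"pixel_values": torch.rand(1, 3, 336, 336), "image_sizes": torch.tensor([[600, 800]])},
]
first = image_inputs[0]
merged = {"pixel_values": torch.cat([img["pixel_values"] for img in image_inputs], dim=0)}
for key in ("pixel_attention_mask", "image_sizes"):
    if key in first:
        merged[key] = torch.cat([img[key] for img in image_inputs], dim=0)
print(merged["pixel_values"].shape, merged["image_sizes"].shape)  # (2, 3, 336, 336), (2, 2)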
@@ -187,9 +231,19 @@ class VlmCausalLMBatch(FlashMistralBatch):
        batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
        if image_inputs is not None:
            batch.pixel_values = image_inputs["pixel_values"].to(device=device)
            if "pixel_attention_mask" in image_inputs:
                batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
                    device=device
                )
            else:
                batch.pixel_attention_mask = None
            if "image_sizes" in image_inputs:
                batch.image_sizes = image_inputs["image_sizes"].to(device=device)
            else:
                batch.image_sizes = None
        else:
            batch.pixel_values = None
            batch.pixel_attention_mask = None
            batch.image_sizes = None
        return batch
@@ -199,16 +253,6 @@ class VlmCausalLM(BaseFlashMistral):
    def batch_type(self) -> Type[VlmCausalLMBatch]:
        return VlmCausalLMBatch

-    def get_layer_config(self, model) -> Tuple[int, int, int]:
-        return (
-            len(model.language_model.model.layers),
-            model.language_model.model.num_key_value_heads,
-            model.language_model.model.head_size,
-        )
-
-    def max_past(self) -> Optional[int]:
-        return getattr(self.model.language_model, "max_past", None)

    def forward(
        self, batch: VlmCausalLMBatch
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@@ -270,17 +314,14 @@ class VlmCausalLM(BaseFlashMistral):
            max_s = min(self.max_past(), max_s)

-        bs = input_ids.shape[0]
-        padded_bs = bs
-        if bs == 3:
-            padded_bs = 4
-        elif 3 < bs <= 8:
-            padded_bs = 8
-        elif bs > 8:
-            padded_bs = (bs + 7) // 8 * 8
-        # Try to find an associated cuda graph
-        cuda_graph = self.cuda_graphs.get(padded_bs, None)
+        bs = input_ids.shape[0]
+        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
+        if sorted_padded_bs:
+            # Get associated cuda graph
+            cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
+        else:
+            cuda_graph = None

        if cu_seqlen_prefill is not None or cuda_graph is None:
            logits, speculative_logits = self.model.forward(
                input_ids=input_ids,
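A toy sketch of the new lookup above: rather than padding the batch size with hardcoded rules, the smallest captured CUDA graph that can hold the current batch is reused, falling back to eager execution when none fits. The dictionary of graphs below is purely illustrative.

# Illustrative only: pick the smallest captured graph whose batch size covers bs.
cuda_graphs = {1: "graph_1", 2: "graph_2", 4: "graph_4", 8: "graph_8"}  # hypothetical
bs = 3
sorted_padded_bs = sorted(k for k in cuda_graphs if k >= bs)
cuda_graph = cuda_graphs[sorted_padded_bs[0]] if sorted_padded_bs else None
print(cuda_graph)  # "graph_4"; with bs=9 no captured graph fits, so cuda_graph is None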
@@ -294,12 +335,15 @@ class VlmCausalLM(BaseFlashMistral):
                prefill_cache_indices=batch.prefill_cache_indices,
                lm_head_indices=lm_head_indices,
                pixel_values=batch.pixel_values,
                pixel_attention_mask=batch.pixel_attention_mask,
                image_sizes=batch.image_sizes,
            )
            if batch.prefill_cache_indices is not None:
                batch.prefill_cache_indices = None
            if batch.pixel_values is not None:
                batch.pixel_values = None
            if batch.pixel_attention_mask is not None:
                batch.pixel_attention_mask = None
            if batch.image_sizes is not None:
                batch.image_sizes = None
            return logits, speculative_logits
server/text_generation_server/server.py (view file @ 5a1cf2f0)
@@ -2,6 +2,7 @@ import asyncio
import os
import torch
import time
import signal

from grpc import aio
from loguru import logger
@@ -19,6 +20,21 @@ from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch


class SignalHandler:
    KEEP_PROCESSING = True

    def __init__(self):
        signal.signal(signal.SIGINT, self.exit_gracefully)
        signal.signal(signal.SIGTERM, self.exit_gracefully)

    def exit_gracefully(self, signum, frame):
        print(f"Exiting gracefully: Signal {signum}")
        self.KEEP_PROCESSING = False


signal_handler = SignalHandler()


class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
    def __init__(
        self,
@@ -231,11 +247,8 @@ def serve(
        logger.info("Server started at {}".format(local_url))

-        try:
-            await server.wait_for_termination()
-        except KeyboardInterrupt:
-            logger.info("Signal received. Shutting down")
-            await server.stop(0)
+        while signal_handler.KEEP_PROCESSING:
+            await asyncio.sleep(0.5)

    asyncio.run(
        serve_inner(
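Taken together with the SignalHandler added above, the serving coroutine now polls a flag instead of waiting for KeyboardInterrupt: SIGINT or SIGTERM flips KEEP_PROCESSING, the loop exits, and the process can finish shutting down cleanly. Below is a minimal standalone sketch of that pattern, not the TGI code itself.

# Minimal sketch of the graceful-shutdown pattern introduced above.
import asyncio
import signal


class SignalHandler:
    KEEP_PROCESSING = True

    def __init__(self):
        signal.signal(signal.SIGINT, self.exit_gracefully)
        signal.signal(signal.SIGTERM, self.exit_gracefully)

    def exit_gracefully(self, signum, frame):
        print(f"Exiting gracefully: Signal {signum}")
        self.KEEP_PROCESSING = False


signal_handler = SignalHandler()


async def serve():
    print("serving; press Ctrl-C to stop")
    while signal_handler.KEEP_PROCESSING:   # the signal handler ends this loop
        await asyncio.sleep(0.5)
    print("shutting down")


if __name__ == "__main__":
    asyncio.run(serve())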
server/text_generation_server/utils/dist.py (view file @ 5a1cf2f0)
@@ -57,6 +57,13 @@ def initialize_torch_distributed():
        options.is_high_priority_stream = True
        options._timeout = timedelta(seconds=60)
    else:
        try:
            import oneccl_bindings_for_pytorch

            backend = "ccl"
            if os.getenv("CCL_WORKER_COUNT", None) is None:
                os.environ["CCL_WORKER_COUNT"] = str(1)
        except ImportError:
            backend = "gloo"
        options = None
server/text_generation_server/utils/flash_attn.py (view file @ 5a1cf2f0): diff collapsed, not shown.
server/text_generation_server/utils/import_utils.py (view file @ 5a1cf2f0)

import torch


def is_xpu_available():
    try:
        import intel_extension_for_pytorch
    except ImportError:
        return False

    return hasattr(torch, "xpu") and torch.xpu.is_available()


IS_ROCM_SYSTEM = torch.version.hip is not None
IS_CUDA_SYSTEM = torch.version.cuda is not None
IS_XPU_SYSTEM = is_xpu_available()
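A hedged usage sketch for these flags: the dispatch below is illustrative only, and relies on the fact that ROCm builds of PyTorch expose the "cuda" device string, which is why IS_ROCM_SYSTEM can share the CUDA branch.

# Illustrative only: branching on the platform flags defined in import_utils.py.
from text_generation_server.utils.import_utils import (
    IS_CUDA_SYSTEM,
    IS_ROCM_SYSTEM,
    IS_XPU_SYSTEM,
)

if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
    device = "cuda"   # ROCm PyTorch also uses the "cuda" device string
elif IS_XPU_SYSTEM:
    device = "xpu"
else:
    device = "cpu"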
server/text_generation_server/utils/layers.py (view file @ 5a1cf2f0): diff collapsed, not shown.
server/text_generation_server/utils/logits_process.py (view file @ 5a1cf2f0)

@@ -143,13 +143,16 @@ class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
        score = torch.gather(scores, 1, input_ids)
        # if score < 0 then penalty has to be multiplied to reduce the previous token probability
        score = -torch.where(score < 0, score * self.penalty, score / self.penalty)
+        # set score to 0 where input_ids is a padding token
+        score *= input_ids.ne(0)

        return scores.scatter_add_(1, input_ids, score)


class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
    r"""
-    Frequency penalty as defined by OpenAI
+    Frequency penalty as defined by OpenAI in
+    https://platform.openai.com/docs/guides/text-generation/parameter-details

    Args:
        frequency_penalty (`List[float]`):
@@ -163,13 +166,19 @@ class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
        ).unsqueeze(1)

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
-        score = torch.gather(scores, 1, input_ids)
-        # if score < 0 then penalty has to be multiplied to reduce the previous token probability
-        score = -torch.where(
-            score < 0, score * self.penalty_tensor, score / self.penalty_tensor
-        )
-
-        return scores.scatter_add_(1, input_ids, score)
+        batch_size, input_size = input_ids.size()
+        vocab_size = scores.size(1)
+
+        # Calculate the frequency for each token so far
+        token_freq = torch.zeros(batch_size, vocab_size, device=input_ids.device)
+        token_freq.scatter_add_(
+            1, input_ids, torch.ones_like(input_ids, dtype=torch.float)
+        )
+        token_freq /= input_size
+
+        # Apply the frequency penalty to logits
+        scores -= token_freq * self.penalty_tensor
+        return scores

    def filter(self, indices):
        self.penalty = [self.penalty[i] for i in indices]
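A small worked example of the rewritten heterogeneous frequency penalty above, with made-up numbers: each token's logit is reduced in proportion to how often it has already appeared in the generated sequence.

# Minimal sketch of the frequency-penalty math above (toy vocabulary and values).
import torch

input_ids = torch.tensor([[5, 5, 7]])        # one request; token 5 has appeared twice
scores = torch.zeros(1, 10)                  # logits over a toy 10-token vocabulary
penalty_tensor = torch.tensor([[0.5]])       # per-request frequency penalty

batch_size, input_size = input_ids.size()
vocab_size = scores.size(1)

token_freq = torch.zeros(batch_size, vocab_size)
token_freq.scatter_add_(1, input_ids, torch.ones_like(input_ids, dtype=torch.float))
token_freq /= input_size                     # token 5 -> 2/3, token 7 -> 1/3

scores -= token_freq * penalty_tensor        # logit of token 5 drops by 0.5 * 2/3
print(scores[0, 5].item(), scores[0, 7].item())   # approx -0.333 and -0.167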
server/text_generation_server/utils/paged_attention.py (view file @ 5a1cf2f0): diff collapsed, not shown.

server/text_generation_server/utils/tokens.py (view file @ 5a1cf2f0): diff collapsed, not shown.

server/text_generation_server/utils/weights.py (view file @ 5a1cf2f0): diff collapsed, not shown.