OpenDAS / text-generation-inference

Commit 5a1cf2f0
Authored May 22, 2024 by huangwb
Merge tag 'v2.0.2' into dev-rocm
Parents: 24f58bb6, 6073ece4
Changes: 70
Showing 10 changed files with 338 additions and 126 deletions (+338, -126); this is page 1 of 4 of the diff.
Changed files on this page:
- server/text_generation_server/models/vlm_causal_lm.py (+80, -36)
- server/text_generation_server/server.py (+18, -5)
- server/text_generation_server/utils/dist.py (+8, -1)
- server/text_generation_server/utils/flash_attn.py (+87, -53)
- server/text_generation_server/utils/import_utils.py (+11, -0)
- server/text_generation_server/utils/layers.py (+53, -5)
- server/text_generation_server/utils/logits_process.py (+15, -6)
- server/text_generation_server/utils/paged_attention.py (+29, -5)
- server/text_generation_server/utils/tokens.py (+24, -9)
- server/text_generation_server/utils/weights.py (+13, -6)
server/text_generation_server/models/vlm_causal_lm.py
@@ -64,6 +64,46 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     return height // patch_size, width // patch_size
 
 
+def image_text_replacement(image_input, config, image_id) -> str:
+    if config.model_type == "idefics2":
+        # TODO technically depends on image splitting which is not implemented.
+        num_features = 320
+        return (
+            "<fake_token_around_image>"
+            + "<image>" * num_features
+            + "<fake_token_around_image>"
+        )
+    elif config.model_type == "llava_next":
+        height, width = image_input["image_sizes"][image_id]
+        num_features = get_number_of_features(height, width, config)
+        from loguru import logger
+
+        logger.info(f"Found {num_features} in image of resolution {height}x{width}")
+        return "<image>" * num_features
+    else:
+        raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
+
+
+def get_unpadded_features(
+    height: int,
+    width: int,
+    npatches: int,
+    num_patch_height: int,
+    num_patch_width: int,
+) -> Tuple[int, int]:
+    current_height = npatches * num_patch_height
+    current_width = npatches * num_patch_width
+
+    aspect_ratio: float = width / height
+    current_aspect_ratio: float = current_width / current_height
+    if aspect_ratio > current_aspect_ratio:
+        new_height = (height * current_width) // width
+        current_height = new_height
+    else:
+        new_width = (width * current_height) // height
+        current_width = new_width
+
+    unpadded_features = current_height * current_width
+    newline_features = current_height
+    return (unpadded_features, newline_features)
+
+
 def get_number_of_features(height: int, width: int, config) -> int:
     # From config
     # Hardcoded for CLIP for now
@@ -81,12 +121,9 @@ def get_number_of_features(height: int, width: int, config) -> int:
         image_grid_pinpoints,
         image_size,
     )
-    height_of_patch = math.ceil(height / width * npatches)
-
-    unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width
-    # They are only added after width
-    newline_features = height_of_patch * num_patch_width
+    unpadded_features, newline_features = get_unpadded_features(
+        height, width, npatches, num_patch_height, num_patch_width
+    )
     # The base patch covers the entire image
     base_features = npatches**2
     return unpadded_features + newline_features + base_features
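For reference, the feature-count arithmetic above can be reproduced with a small standalone sketch; it matches the assertion comment removed later in this file (get_number_of_features(640, 640) == 2928). The 336-pixel / 14-pixel-patch CLIP tower and the 2x2 anyres grid are assumptions for the example, not values read from this diff.

from typing import Tuple


def get_unpadded_features(
    height: int, width: int, npatches: int, num_patch_height: int, num_patch_width: int
) -> Tuple[int, int]:
    # Mirrors the helper added above: shrink the padded grid back to the image aspect ratio.
    current_height = npatches * num_patch_height
    current_width = npatches * num_patch_width
    aspect_ratio = width / height
    current_aspect_ratio = current_width / current_height
    if aspect_ratio > current_aspect_ratio:
        current_height = (height * current_width) // width
    else:
        current_width = (width * current_height) // height
    return current_height * current_width, current_height


if __name__ == "__main__":
    npatches = 336 // 14  # assumed CLIP-style vision tower: 24 patches per side
    unpadded, newline = get_unpadded_features(640, 640, npatches, 2, 2)
    base = npatches**2  # the base patch covers the entire image
    print(unpadded + newline + base)  # 2304 + 48 + 576 = 2928 "<image>" tokens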
@@ -99,12 +136,9 @@ def load_data_uri(image_uri: str) -> Image.Image:
     return image
 
 
-# assert get_number_of_features(889, 1024) == 2634, f"{get_number_of_features(889, 1024)}"
-# assert get_number_of_features(640, 640) == 2928
-
 
 class VlmCausalLMBatch(FlashMistralBatch):
     pixel_values: Optional[List[torch.Tensor]]
+    pixel_attention_mask: Optional[List[torch.Tensor]]
     image_sizes: Optional[List[Tuple[int, int]]]
 
     @classmethod
@@ -112,6 +146,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
     def concatenate(cls, batches):
         batch = super(VlmCausalLMBatch, cls).concatenate(batches)
         batch.pixel_values = None
+        batch.pixel_attention_mask = None
         batch.image_sizes = None
         return batch
@@ -119,6 +154,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
     def filter(self, request_ids: List[int]):
         batch = super().filter(request_ids)
         batch.pixel_values = None
+        batch.pixel_attention_mask = None
         batch.image_sizes = None
         return batch
@@ -130,6 +166,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
         for r in requests:
             chunks = split(r.inputs)
             full_text = ""
+            image_id = 0
             for chunk in chunks:
                 if chunk["type"] == "text":
                     full_text += chunk["content"]
@@ -147,9 +184,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
                         "Cannot process input image not starting with data:"
                     )
                     image_input = processor.image_processor(image, return_tensors="pt")
-                    height, width = image_input["image_sizes"][0]
-                    num_features = get_number_of_features(height, width, config)
-                    full_text += "<image>" * num_features
+                    full_text += image_text_replacement(image_input, config, image_id)
                     image_inputs.append(image_input)
                 else:
                     raise RuntimeError(f"Invalid chunk type {chunk['type']}")
@@ -161,12 +196,21 @@ class VlmCausalLMBatch(FlashMistralBatch):
             batch_inputs, truncation=True, max_length=max_truncation
         )["input_ids"]
 
         if image_inputs:
-            image_inputs = {
+            image_input = image_inputs[0]
+            new_image_inputs = {
                 "pixel_values": torch.cat(
                     [img["pixel_values"] for img in image_inputs], dim=0
                 ),
-                "image_sizes": torch.cat([img["image_sizes"] for img in image_inputs]),
             }
+            if "pixel_attention_mask" in image_input:
+                new_image_inputs["pixel_attention_mask"] = torch.cat(
+                    [img["pixel_attention_mask"] for img in image_inputs], dim=0
+                )
+            if "image_sizes" in image_input:
+                new_image_inputs["image_sizes"] = torch.cat(
+                    [img["image_sizes"] for img in image_inputs], dim=0
+                )
+            image_inputs = new_image_inputs
         else:
             image_inputs = None
 
         return batch_tokenized_inputs, image_inputs
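A self-contained sketch of the collation pattern above: per-image processor outputs are dicts of tensors, and any optional keys present on the first image (an attention mask or image sizes, depending on the processor) are concatenated along the batch dimension. Shapes below are made up for illustration.

import torch

image_inputs = [
    {"pixel_values": torch.zeros(1, 3, 336, 336), "image_sizes": torch.tensor([[480, 640]])},
    {"pixel_values": torch.zeros(1, 3, 336, 336), "image_sizes": torch.tensor([[336, 336]])},
]

first = image_inputs[0]
batched = {"pixel_values": torch.cat([img["pixel_values"] for img in image_inputs], dim=0)}
for key in ("pixel_attention_mask", "image_sizes"):
    if key in first:  # optional keys are only batched when the processor emitted them
        batched[key] = torch.cat([img[key] for img in image_inputs], dim=0)

print(batched["pixel_values"].shape)  # torch.Size([2, 3, 336, 336])
print(batched["image_sizes"].shape)   # torch.Size([2, 2])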
@@ -187,9 +231,19 @@ class VlmCausalLMBatch(FlashMistralBatch):
         batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
         if image_inputs is not None:
             batch.pixel_values = image_inputs["pixel_values"].to(device=device)
+            if "pixel_attention_mask" in image_inputs:
+                batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
+                    device=device
+                )
+            else:
+                batch.pixel_attention_mask = None
+            if "image_sizes" in image_inputs:
                 batch.image_sizes = image_inputs["image_sizes"].to(device=device)
+            else:
+                batch.image_sizes = None
         else:
             batch.pixel_values = None
+            batch.pixel_attention_mask = None
             batch.image_sizes = None
         return batch
@@ -199,16 +253,6 @@ class VlmCausalLM(BaseFlashMistral):
     def batch_type(self) -> Type[VlmCausalLMBatch]:
         return VlmCausalLMBatch
 
-    def get_layer_config(self, model) -> Tuple[int, int, int]:
-        return (
-            len(model.language_model.model.layers),
-            model.language_model.model.num_key_value_heads,
-            model.language_model.model.head_size,
-        )
-
-    def max_past(self) -> Optional[int]:
-        return getattr(self.model.language_model, "max_past", None)
-
     def forward(
         self, batch: VlmCausalLMBatch
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@@ -270,17 +314,14 @@ class VlmCausalLM(BaseFlashMistral):
             max_s = min(self.max_past(), max_s)
 
         bs = input_ids.shape[0]
-        padded_bs = bs
-        if bs == 3:
-            padded_bs = 4
-        elif 3 < bs <= 8:
-            padded_bs = 8
-        elif bs > 8:
-            padded_bs = (bs + 7) // 8 * 8
-
         # Try to find an associated cuda graph
-        cuda_graph = self.cuda_graphs.get(padded_bs, None)
+        bs = input_ids.shape[0]
+        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
+        if sorted_padded_bs:
+            # Get associated cuda graph
+            cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
+        else:
+            cuda_graph = None
 
         if cu_seqlen_prefill is not None or cuda_graph is None:
             logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
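The hunk above swaps the hard-coded padding rule for a lookup of the smallest captured CUDA-graph batch size that can hold the current batch. A toy sketch of that policy (the captured sizes and names are assumptions):

cuda_graphs = {1: "graph_1", 2: "graph_2", 4: "graph_4", 8: "graph_8", 16: "graph_16"}


def pick_graph(bs: int):
    sorted_padded_bs = sorted(k for k in cuda_graphs.keys() if k >= bs)
    # smallest captured size that fits; otherwise fall back to an eager forward pass
    return cuda_graphs[sorted_padded_bs[0]] if sorted_padded_bs else None


print(pick_graph(3))   # graph_4
print(pick_graph(9))   # graph_16
print(pick_graph(32))  # None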
@@ -294,12 +335,15 @@ class VlmCausalLM(BaseFlashMistral):
                 prefill_cache_indices=batch.prefill_cache_indices,
                 lm_head_indices=lm_head_indices,
                 pixel_values=batch.pixel_values,
+                pixel_attention_mask=batch.pixel_attention_mask,
                 image_sizes=batch.image_sizes,
             )
             if batch.prefill_cache_indices is not None:
                 batch.prefill_cache_indices = None
             if batch.pixel_values is not None:
                 batch.pixel_values = None
+            if batch.pixel_attention_mask is not None:
+                batch.pixel_attention_mask = None
             if batch.image_sizes is not None:
                 batch.image_sizes = None
             return logits, speculative_logits
server/text_generation_server/server.py
@@ -2,6 +2,7 @@ import asyncio
 import os
 import torch
 import time
+import signal
 
 from grpc import aio
 from loguru import logger
@@ -19,6 +20,21 @@ from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
 from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
 
 
+class SignalHandler:
+    KEEP_PROCESSING = True
+
+    def __init__(self):
+        signal.signal(signal.SIGINT, self.exit_gracefully)
+        signal.signal(signal.SIGTERM, self.exit_gracefully)
+
+    def exit_gracefully(self, signum, frame):
+        print(f"Exiting gracefully: Signal {signum}")
+        self.KEEP_PROCESSING = False
+
+
+signal_handler = SignalHandler()
+
+
 class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
     def __init__(
         self,
@@ -231,11 +247,8 @@ def serve(
         logger.info("Server started at {}".format(local_url))
 
-        try:
-            await server.wait_for_termination()
-        except KeyboardInterrupt:
-            logger.info("Signal received. Shutting down")
-            await server.stop(0)
+        while signal_handler.KEEP_PROCESSING:
+            await asyncio.sleep(0.5)
 
     asyncio.run(
         serve_inner(
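Taken together, the two server.py hunks replace the KeyboardInterrupt handling with a flag flipped by SIGINT/SIGTERM. A self-contained sketch of the pattern, with a dummy polling loop standing in for the running gRPC server:

import asyncio
import signal


class SignalHandler:
    KEEP_PROCESSING = True

    def __init__(self):
        signal.signal(signal.SIGINT, self.exit_gracefully)
        signal.signal(signal.SIGTERM, self.exit_gracefully)

    def exit_gracefully(self, signum, frame):
        print(f"Exiting gracefully: Signal {signum}")
        self.KEEP_PROCESSING = False


signal_handler = SignalHandler()


async def serve_inner():
    # the real server keeps handling requests in the background; here we only poll the flag
    while signal_handler.KEEP_PROCESSING:
        await asyncio.sleep(0.5)


if __name__ == "__main__":
    asyncio.run(serve_inner())  # Ctrl+C or SIGTERM lets the loop finish its current sleep and exit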
server/text_generation_server/utils/dist.py
@@ -57,6 +57,13 @@ def initialize_torch_distributed():
         options.is_high_priority_stream = True
         options._timeout = timedelta(seconds=60)
     else:
-        backend = "gloo"
+        try:
+            import oneccl_bindings_for_pytorch
+
+            backend = "ccl"
+            if os.getenv("CCL_WORKER_COUNT", None) is None:
+                os.environ["CCL_WORKER_COUNT"] = str(1)
+        except ImportError:
+            backend = "gloo"
         options = None
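A standalone sketch of the CPU backend selection added above: prefer oneCCL when Intel's bindings import cleanly, otherwise keep the previous gloo fallback. The helper name and the commented-out init call are illustrative only.

import os


def pick_cpu_backend() -> str:
    try:
        import oneccl_bindings_for_pytorch  # noqa: F401

        if os.getenv("CCL_WORKER_COUNT", None) is None:
            os.environ["CCL_WORKER_COUNT"] = str(1)
        return "ccl"
    except ImportError:
        return "gloo"


if __name__ == "__main__":
    backend = pick_cpu_backend()
    print(backend)
    # torch.distributed.init_process_group(backend=backend, rank=RANK, world_size=WORLD_SIZE)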
server/text_generation_server/utils/flash_attn.py
(diff collapsed on this page)
server/text_generation_server/utils/import_utils.py

 import torch
 
 
+def is_xpu_available():
+    try:
+        import intel_extension_for_pytorch
+    except ImportError:
+        return False
+
+    return hasattr(torch, "xpu") and torch.xpu.is_available()
+
+
 IS_ROCM_SYSTEM = torch.version.hip is not None
 IS_CUDA_SYSTEM = torch.version.cuda is not None
+IS_XPU_SYSTEM = is_xpu_available()
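A hypothetical usage sketch of the new platform flags; the get_device helper below is not part of the diff and only illustrates how downstream code might branch on them.

import torch

from text_generation_server.utils.import_utils import (
    IS_CUDA_SYSTEM,
    IS_ROCM_SYSTEM,
    IS_XPU_SYSTEM,
)


def get_device() -> torch.device:
    # ROCm builds of PyTorch also expose the "cuda" device string
    if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
        return torch.device("cuda")
    if IS_XPU_SYSTEM:
        return torch.device("xpu")
    return torch.device("cpu")


print(get_device())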
server/text_generation_server/utils/layers.py
(diff collapsed on this page)
server/text_generation_server/utils/logits_process.py
@@ -143,13 +143,16 @@ class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
         score = torch.gather(scores, 1, input_ids)
         # if score < 0 then penalty has to be multiplied to reduce the previous token probability
         score = -torch.where(score < 0, score * self.penalty, score / self.penalty)
+        # set score to 0 where input_ids is a padding token
+        score *= input_ids.ne(0)
 
         return scores.scatter_add_(1, input_ids, score)
 
 
 class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
     r"""
-    Frequency penalty as defined by OpenAI
+    Frequency penalty as defined by OpenAI in
+    https://platform.openai.com/docs/guides/text-generation/parameter-details
 
     Args:
         frequency_penalty (`List[float]`):
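A quick sketch of the padding guard added above: where input_ids holds the padding id 0, the gathered penalty is zeroed before being scattered back into the logits. The numbers are made up.

import torch

input_ids = torch.tensor([[0, 0, 7, 2]])            # left-padded sequence
score = torch.tensor([[-1.5, -1.5, -0.5, -0.25]])   # penalties gathered per position
score = score * input_ids.ne(0)                     # padding slots contribute nothing
print(score)                                        # only the positions holding tokens 7 and 2 keep a penalty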
@@ -163,13 +166,19 @@ class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
         ).unsqueeze(1)
 
     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
-        score = torch.gather(scores, 1, input_ids)
-        # if score < 0 then penalty has to be multiplied to reduce the previous token probability
-        score = -torch.where(
-            score < 0, score * self.penalty_tensor, score / self.penalty_tensor
-        )
-
-        return scores.scatter_add_(1, input_ids, score)
+        batch_size, input_size = input_ids.size()
+        vocab_size = scores.size(1)
+
+        # Calculate the frequency for each token so far
+        token_freq = torch.zeros(batch_size, vocab_size, device=input_ids.device)
+        token_freq.scatter_add_(
+            1, input_ids, torch.ones_like(input_ids, dtype=torch.float)
+        )
+        token_freq /= input_size
+
+        # Apply the frequency penalty to logits
+        scores -= token_freq * self.penalty_tensor
+        return scores
 
     def filter(self, indices):
         self.penalty = [self.penalty[i] for i in indices]
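The rewritten __call__ above penalizes each vocabulary id in proportion to how often it already occurs in the generated ids, instead of rescaling gathered scores. A runnable sketch of that arithmetic with made-up values:

import torch

input_ids = torch.tensor([[5, 5, 7, 2]])   # token 5 generated twice out of 4 steps
scores = torch.zeros(1, 10)                # batch of 1, vocab of 10
penalty_tensor = torch.tensor([[2.0]])     # per-request frequency penalty

token_freq = torch.zeros(1, 10)
token_freq.scatter_add_(1, input_ids, torch.ones_like(input_ids, dtype=torch.float))
token_freq /= input_ids.size(1)            # frequencies: 0.5 for token 5, 0.25 for 7 and 2

scores -= token_freq * penalty_tensor
print(scores[0, 5].item(), scores[0, 7].item(), scores[0, 2].item())  # -1.0 -0.5 -0.5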
server/text_generation_server/utils/paged_attention.py
(diff collapsed on this page)
server/text_generation_server/utils/tokens.py
(diff collapsed on this page)
server/text_generation_server/utils/weights.py
(diff collapsed on this page)