OpenDAS / text-generation-inference / Commits

Commit 5a1cf2f0, authored May 22, 2024 by huangwb

Merge tag 'v2.0.2' into dev-rocm

Parents: 24f58bb6, 6073ece4
Changes: 70 changed files in total; this page shows 10 changed files with 338 additions and 126 deletions (+338 -126).
server/text_generation_server/models/vlm_causal_lm.py   +80  -36
server/text_generation_server/server.py                 +18   -5
server/text_generation_server/utils/dist.py              +8   -1
server/text_generation_server/utils/flash_attn.py       +87  -53
server/text_generation_server/utils/import_utils.py     +11   -0
server/text_generation_server/utils/layers.py           +53   -5
server/text_generation_server/utils/logits_process.py   +15   -6
server/text_generation_server/utils/paged_attention.py  +29   -5
server/text_generation_server/utils/tokens.py           +24   -9
server/text_generation_server/utils/weights.py          +13   -6
server/text_generation_server/models/vlm_causal_lm.py  (view file @ 5a1cf2f0)

@@ -64,6 +64,46 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     return height // patch_size, width // patch_size


+def image_text_replacement(image_input, config, image_id) -> str:
+    if config.model_type == "idefics2":
+        # TODO technically depends on image splitting which is not implemented.
+        num_features = 320
+        return (
+            "<fake_token_around_image>"
+            + "<image>" * num_features
+            + "<fake_token_around_image>"
+        )
+    elif config.model_type == "llava_next":
+        height, width = image_input["image_sizes"][image_id]
+        num_features = get_number_of_features(height, width, config)
+        from loguru import logger
+
+        logger.info(f"Found {num_features} in image of resolution {height}x{width}")
+        return "<image>" * num_features
+    else:
+        raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
+
+
+def get_unpadded_features(
+    height: int,
+    width: int,
+    npatches: int,
+    num_patch_height: int,
+    num_patch_width: int,
+) -> Tuple[int, int]:
+    current_height = npatches * num_patch_height
+    current_width = npatches * num_patch_width
+
+    aspect_ratio: float = width / height
+    current_aspect_ratio: float = current_width / current_height
+    if aspect_ratio > current_aspect_ratio:
+        new_height = (height * current_width) // width
+        current_height = new_height
+    else:
+        new_width = (width * current_height) // height
+        current_width = new_width
+
+    unpadded_features = current_height * current_width
+    newline_features = current_height
+    return (unpadded_features, newline_features)
+
+
 def get_number_of_features(height: int, width: int, config) -> int:
     # From config
     # Hardcoded for CLIP for now

@@ -81,12 +121,9 @@ def get_number_of_features(height: int, width: int, config) -> int:
         image_grid_pinpoints,
         image_size,
     )
-    height_of_patch = math.ceil(height / width * npatches)
-
-    unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width
-    # They are only added after width
-    newline_features = height_of_patch * num_patch_width
+    unpadded_features, newline_features = get_unpadded_features(
+        height, width, npatches, num_patch_height, num_patch_width
+    )
     # The base patch covers the entire image
     base_features = npatches**2
     return unpadded_features + newline_features + base_features
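Note: a minimal standalone sketch of the feature-count arithmetic introduced above. The concrete values (npatches=24, i.e. a CLIP ViT at 336px with 14px patches, and a 2x2 patch grid for a 640x640 input) are assumptions chosen for the illustration, not taken from the commit; the total matches the `get_number_of_features(640, 640) == 2928` assertion that appears, commented out, in this file.

def unpadded_features(height, width, npatches, num_patch_height, num_patch_width):
    # Same arithmetic as the new get_unpadded_features helper.
    current_height = npatches * num_patch_height
    current_width = npatches * num_patch_width
    if width / height > current_width / current_height:
        current_height = (height * current_width) // width
    else:
        current_width = (width * current_height) // height
    # one extra "newline" feature per row of the unpadded grid
    return current_height * current_width, current_height

features, newlines = unpadded_features(640, 640, npatches=24, num_patch_height=2, num_patch_width=2)
print(features + newlines + 24**2)  # 2304 + 48 + 576 = 2928 (unpadded + newline + base patch)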
@@ -99,12 +136,9 @@ def load_data_uri(image_uri: str) -> Image.Image:
     return image


-# assert get_number_of_features(889, 1024) == 2634, f"{get_number_of_features(889, 1024)}"
-# assert get_number_of_features(640, 640) == 2928


 class VlmCausalLMBatch(FlashMistralBatch):
     pixel_values: Optional[List[torch.Tensor]]
+    pixel_attention_mask: Optional[List[torch.Tensor]]
     image_sizes: Optional[List[Tuple[int, int]]]

     @classmethod

@@ -112,6 +146,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
     def concatenate(cls, batches):
         batch = super(VlmCausalLMBatch, cls).concatenate(batches)
         batch.pixel_values = None
+        batch.pixel_attention_mask = None
         batch.image_sizes = None
         return batch

@@ -119,6 +154,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
     def filter(self, request_ids: List[int]):
         batch = super().filter(request_ids)
         batch.pixel_values = None
+        batch.pixel_attention_mask = None
         batch.image_sizes = None
         return batch

@@ -130,6 +166,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
         for r in requests:
             chunks = split(r.inputs)
             full_text = ""
+            image_id = 0
             for chunk in chunks:
                 if chunk["type"] == "text":
                     full_text += chunk["content"]

@@ -147,9 +184,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
                             "Cannot process input image not starting with data:"
                         )
                     image_input = processor.image_processor(image, return_tensors="pt")
-                    height, width = image_input["image_sizes"][0]
-                    num_features = get_number_of_features(height, width, config)
-                    full_text += "<image>" * num_features
+                    full_text += image_text_replacement(image_input, config, image_id)
                     image_inputs.append(image_input)
                 else:
                     raise RuntimeError(f"Invalid chunk type {chunk['type']}")

@@ -161,12 +196,21 @@ class VlmCausalLMBatch(FlashMistralBatch):
             batch_inputs, truncation=True, max_length=max_truncation
         )["input_ids"]

         if image_inputs:
-            image_inputs = {
+            image_input = image_inputs[0]
+            new_image_inputs = {
                 "pixel_values": torch.cat(
                     [img["pixel_values"] for img in image_inputs], dim=0
                 ),
-                "image_sizes": torch.cat([img["image_sizes"] for img in image_inputs]),
             }
+            if "pixel_attention_mask" in image_input:
+                new_image_inputs["pixel_attention_mask"] = torch.cat(
+                    [img["pixel_attention_mask"] for img in image_inputs], dim=0
+                )
+            if "image_sizes" in image_input:
+                new_image_inputs["image_sizes"] = torch.cat(
+                    [img["image_sizes"] for img in image_inputs], dim=0
+                )
+            image_inputs = new_image_inputs
         else:
             image_inputs = None
         return batch_tokenized_inputs, image_inputs
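Note: a small illustration of the batching pattern above, where optional keys are concatenated only when the processor produced them. The tensor shapes are made up for the example; idefics2 is the case that additionally emits pixel_attention_mask.

import torch

# Two fake processor outputs; the 3x224x224 shape is illustrative only.
image_inputs = [
    {"pixel_values": torch.randn(1, 3, 224, 224), "image_sizes": torch.tensor([[224, 224]])},
    {"pixel_values": torch.randn(1, 3, 224, 224), "image_sizes": torch.tensor([[224, 224]])},
]

batched = {"pixel_values": torch.cat([img["pixel_values"] for img in image_inputs], dim=0)}
# Optional keys are concatenated only if the first image_input contains them.
if "image_sizes" in image_inputs[0]:
    batched["image_sizes"] = torch.cat([img["image_sizes"] for img in image_inputs], dim=0)
print(batched["pixel_values"].shape)  # torch.Size([2, 3, 224, 224])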
@@ -187,9 +231,19 @@ class VlmCausalLMBatch(FlashMistralBatch):
         batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
         if image_inputs is not None:
             batch.pixel_values = image_inputs["pixel_values"].to(device=device)
-            batch.image_sizes = image_inputs["image_sizes"].to(device=device)
+            if "pixel_attention_mask" in image_inputs:
+                batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
+                    device=device
+                )
+            else:
+                batch.pixel_attention_mask = None
+            if "image_sizes" in image_inputs:
+                batch.image_sizes = image_inputs["image_sizes"].to(device=device)
+            else:
+                batch.image_sizes = None
         else:
             batch.pixel_values = None
+            batch.pixel_attention_mask = None
             batch.image_sizes = None
         return batch

@@ -199,16 +253,6 @@ class VlmCausalLM(BaseFlashMistral):
     def batch_type(self) -> Type[VlmCausalLMBatch]:
         return VlmCausalLMBatch

-    def get_layer_config(self, model) -> Tuple[int, int, int]:
-        return (
-            len(model.language_model.model.layers),
-            model.language_model.model.num_key_value_heads,
-            model.language_model.model.head_size,
-        )
-
-    def max_past(self) -> Optional[int]:
-        return getattr(self.model.language_model, "max_past", None)
-
     def forward(
         self, batch: VlmCausalLMBatch
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:

@@ -270,17 +314,14 @@ class VlmCausalLM(BaseFlashMistral):
             max_s = min(self.max_past(), max_s)

-        bs = input_ids.shape[0]
-        padded_bs = bs
-        if bs == 3:
-            padded_bs = 4
-        elif 3 < bs <= 8:
-            padded_bs = 8
-        elif bs > 8:
-            padded_bs = (bs + 7) // 8 * 8
-
-        # Try to find an associated cuda graph
-        cuda_graph = self.cuda_graphs.get(padded_bs, None)
+        bs = input_ids.shape[0]
+        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
+        if sorted_padded_bs:
+            # Get associated cuda graph
+            cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
+        else:
+            cuda_graph = None

         if cu_seqlen_prefill is not None or cuda_graph is None:
             logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
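Note: a small sketch of the new lookup strategy. Instead of rounding the batch size up to a fixed bucket, the forward pass now picks the smallest captured CUDA graph whose batch size covers the current batch, and falls back to eager execution when none does. The captured sizes below are illustrative.

# Illustrative only: pretend graphs were captured for these batch sizes.
cuda_graphs = {1: "graph_1", 2: "graph_2", 4: "graph_4", 8: "graph_8", 16: "graph_16"}

def pick_graph(bs):
    candidates = sorted(k for k in cuda_graphs.keys() if k >= bs)
    return cuda_graphs[candidates[0]] if candidates else None  # None -> run without a graph

print(pick_graph(3))   # graph_4
print(pick_graph(9))   # graph_16
print(pick_graph(32))  # None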
@@ -294,12 +335,15 @@ class VlmCausalLM(BaseFlashMistral):
                 prefill_cache_indices=batch.prefill_cache_indices,
                 lm_head_indices=lm_head_indices,
                 pixel_values=batch.pixel_values,
+                pixel_attention_mask=batch.pixel_attention_mask,
                 image_sizes=batch.image_sizes,
             )
             if batch.prefill_cache_indices is not None:
                 batch.prefill_cache_indices = None
             if batch.pixel_values is not None:
                 batch.pixel_values = None
+            if batch.pixel_attention_mask is not None:
+                batch.pixel_attention_mask = None
             if batch.image_sizes is not None:
                 batch.image_sizes = None
             return logits, speculative_logits
server/text_generation_server/server.py  (view file @ 5a1cf2f0)

@@ -2,6 +2,7 @@ import asyncio
 import os
 import torch
 import time
+import signal

 from grpc import aio
 from loguru import logger

@@ -19,6 +20,21 @@ from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
 from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch


+class SignalHandler:
+    KEEP_PROCESSING = True
+
+    def __init__(self):
+        signal.signal(signal.SIGINT, self.exit_gracefully)
+        signal.signal(signal.SIGTERM, self.exit_gracefully)
+
+    def exit_gracefully(self, signum, frame):
+        print(f"Exiting gracefully: Signal {signum}")
+        self.KEEP_PROCESSING = False
+
+
+signal_handler = SignalHandler()
+
+
 class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
     def __init__(
         self,

@@ -231,11 +247,8 @@ def serve(
         logger.info("Server started at {}".format(local_url))

-        try:
-            await server.wait_for_termination()
-        except KeyboardInterrupt:
-            logger.info("Signal received. Shutting down")
-            await server.stop(0)
+        while signal_handler.KEEP_PROCESSING:
+            await asyncio.sleep(0.5)

     asyncio.run(
         serve_inner(
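Note: a minimal, self-contained sketch of the shutdown pattern introduced here. The 0.5 s polling interval and the SignalHandler class come from the diff; the toy serve loop standing in for the gRPC server is an assumption made for the example.

import asyncio
import signal

class SignalHandler:
    KEEP_PROCESSING = True

    def __init__(self):
        signal.signal(signal.SIGINT, self.exit_gracefully)
        signal.signal(signal.SIGTERM, self.exit_gracefully)

    def exit_gracefully(self, signum, frame):
        print(f"Exiting gracefully: Signal {signum}")
        self.KEEP_PROCESSING = False

async def serve_inner(handler):
    # Stand-in for the real server: instead of blocking on wait_for_termination(),
    # poll the flag so SIGINT/SIGTERM can stop the loop cleanly.
    while handler.KEEP_PROCESSING:
        await asyncio.sleep(0.5)

if __name__ == "__main__":
    asyncio.run(serve_inner(SignalHandler()))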
server/text_generation_server/utils/dist.py  (view file @ 5a1cf2f0)

@@ -57,7 +57,14 @@ def initialize_torch_distributed():
         options.is_high_priority_stream = True
         options._timeout = timedelta(seconds=60)
     else:
-        backend = "gloo"
+        try:
+            import oneccl_bindings_for_pytorch
+
+            backend = "ccl"
+            if os.getenv("CCL_WORKER_COUNT", None) is None:
+                os.environ["CCL_WORKER_COUNT"] = str(1)
+        except ImportError:
+            backend = "gloo"
         options = None

     if WORLD_SIZE == 1:
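Note: the same probe-and-fall-back pattern in isolation. The module name and the CCL_WORKER_COUNT default come from the diff; the helper name is only for the example, and the result would typically be passed to torch.distributed.init_process_group(backend=...).

import os

def pick_cpu_backend() -> str:
    """Prefer Intel oneCCL when its PyTorch bindings are importable, else fall back to gloo."""
    try:
        import oneccl_bindings_for_pytorch  # noqa: F401

        os.environ.setdefault("CCL_WORKER_COUNT", "1")
        return "ccl"
    except ImportError:
        return "gloo"

print(pick_cpu_backend())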
server/text_generation_server/utils/flash_attn.py  (view file @ 5a1cf2f0)

@@ -2,69 +2,81 @@ import os
 import torch

 from loguru import logger
 import math

-from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
+from text_generation_server.utils.import_utils import (
+    IS_CUDA_SYSTEM,
+    IS_ROCM_SYSTEM,
+    IS_XPU_SYSTEM,
+)

 if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
     raise ImportError("`USE_FLASH_ATTENTION` is false.")
+HAS_FLASH_ATTN = True
+HAS_FLASH_ATTN_V2_CUDA = False
+HAS_FLASH_ATTN_V2_ROCM = False

-if not torch.cuda.is_available():
-    raise ImportError("CUDA is not available")
+if IS_XPU_SYSTEM:
+    import intel_extension_for_pytorch as ipex

-major, minor = torch.cuda.get_device_capability()
-is_sm75 = major == 7 and minor == 5
-is_sm8x = major == 8 and minor >= 0
-is_sm90 = major == 9 and minor == 0
-
-HAS_FLASH_ATTN = False
-HAS_FLASH_ATTN_V2_CUDA = False
-HAS_FLASH_ATTN_V2_ROCM = False
-try:
+if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
+    if not torch.cuda.is_available():
+        raise ImportError("CUDA is not available")
+
+    major, minor = torch.cuda.get_device_capability()
+    is_sm75 = major == 7 and minor == 5
+    is_sm8x = major == 8 and minor >= 0
+    is_sm90 = major == 9 and minor == 0
+
+    HAS_FLASH_ATTN = False
+    HAS_FLASH_ATTN_V2_CUDA = False
+    HAS_FLASH_ATTN_V2_ROCM = False
     try:
-        import flash_attn_2_cuda
-    except ImportError:
-        architecture_suffix = ""
-        if IS_CUDA_SYSTEM:
-            architecture_suffix = "-cuda"
-        elif IS_ROCM_SYSTEM:
-            architecture_suffix = "-rocm"
-        raise ImportError(
-            "Flash Attention V2 is not installed.\n"
-            "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
-            f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
-        )
-    if not (is_sm8x or is_sm90):
-        raise ImportError(
-            f"GPU with CUDA capability {major} {minor} is not supported for "
-            "Flash Attention V2"
-        )
-    HAS_FLASH_ATTN_V2_CUDA = IS_CUDA_SYSTEM
-    HAS_FLASH_ATTN_V2_ROCM = IS_ROCM_SYSTEM
-except ImportError as e:
-    try:
-        import flash_attn_cuda
-    except ImportError:
-        raise ImportError(
-            "Flash Attention is not installed.\n"
-            "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
-            "or install flash attention with `cd server && make install install-flash-attention`"
-        ) from e
-
-    if IS_CUDA_SYSTEM and not (is_sm75 or is_sm8x or is_sm90):
-        raise ImportError(f"GPU with CUDA capability {major} {minor} is not supported") from e
-    elif IS_ROCM_SYSTEM:
-        for idx in range(torch.cuda.device_count()):
-            if "MI210" not in torch.cuda.get_device_name(
-                idx
-            ) and "MI250" not in torch.cuda.get_device_name(idx):
-                raise ImportError(
-                    f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
-                )
-
-    logger.warning(f"Unable to use Flash Attention V2: {e}")
-    HAS_FLASH_ATTN = True
+        try:
+            import flash_attn_2_cuda
+        except ImportError:
+            architecture_suffix = ""
+            if IS_CUDA_SYSTEM:
+                architecture_suffix = "-cuda"
+            elif IS_ROCM_SYSTEM:
+                architecture_suffix = "-rocm"
+            raise ImportError(
+                "Flash Attention V2 is not installed.\n"
+                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
+                f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
+            )
+        if not (is_sm8x or is_sm90) and IS_CUDA_SYSTEM:
+            raise ImportError(
+                f"GPU with CUDA capability {major} {minor} is not supported for "
+                "Flash Attention V2"
+            )
+        HAS_FLASH_ATTN_V2_CUDA = IS_CUDA_SYSTEM
+        HAS_FLASH_ATTN_V2_ROCM = IS_ROCM_SYSTEM
+    except ImportError as e:
+        try:
+            import flash_attn_cuda
+        except ImportError:
+            raise ImportError(
+                "Flash Attention is not installed.\n"
+                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
+                "or install flash attention with `cd server && make install install-flash-attention`"
+            ) from e
+
+        if IS_CUDA_SYSTEM and not (is_sm75 or is_sm8x or is_sm90):
+            raise ImportError(
+                f"GPU with CUDA capability {major} {minor} is not supported"
+            ) from e
+        elif IS_ROCM_SYSTEM:
+            for idx in range(torch.cuda.device_count()):
+                if "MI210" not in torch.cuda.get_device_name(
+                    idx
+                ) and "MI250" not in torch.cuda.get_device_name(idx):
+                    raise ImportError(
+                        f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
+                    )
+
+        logger.warning(f"Unable to use Flash Attention V2: {e}")
+        HAS_FLASH_ATTN = True


 def attention(

@@ -80,6 +92,28 @@ def attention(
     if window_size_left <= 0 and window_size_left != -1:
         raise ValueError("`window_size_left` must be > 0 or -1")
+
+    if IS_XPU_SYSTEM:
+        if window_size_left != -1:
+            raise ValueError(
+                f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
+            )
+        return ipex.llm.functional.varlen_attention(
+            q,
+            k,
+            v,
+            out,
+            cu_seqlens,
+            cu_seqlens,
+            max_s,
+            max_s,
+            0.0,
+            softmax_scale,
+            False,
+            True,
+            False,
+            None,
+        )
+
     if HAS_FLASH_ATTN_V2_CUDA:
         return flash_attn_2_cuda.varlen_fwd(
             q,
server/text_generation_server/utils/import_utils.py  (view file @ 5a1cf2f0)

 import torch

+
+def is_xpu_available():
+    try:
+        import intel_extension_for_pytorch
+    except ImportError:
+        return False
+
+    return hasattr(torch, "xpu") and torch.xpu.is_available()
+
+
 IS_ROCM_SYSTEM = torch.version.hip is not None
 IS_CUDA_SYSTEM = torch.version.cuda is not None
+IS_XPU_SYSTEM = is_xpu_available()
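Note: these module-level flags are what the other files in this commit branch on. A quick sketch of how a caller might dispatch on them; the branch bodies are placeholders for the real kernel paths, not code from the repository.

from text_generation_server.utils.import_utils import (
    IS_CUDA_SYSTEM,
    IS_ROCM_SYSTEM,
    IS_XPU_SYSTEM,
)

if IS_ROCM_SYSTEM:
    print("dispatch to the ROCm / vllm kernels")
elif IS_CUDA_SYSTEM:
    print("dispatch to the CUDA flash-attention kernels")
elif IS_XPU_SYSTEM:
    print("dispatch to the intel_extension_for_pytorch (ipex) kernels")
else:
    print("no supported accelerator detected")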
server/text_generation_server/utils/layers.py  (view file @ 5a1cf2f0)

@@ -8,6 +8,8 @@ from typing import List, Tuple, Optional
 from loguru import logger
+from functools import lru_cache
+from text_generation_server.utils.speculate import get_speculate

 HAS_BITS_AND_BYTES = True
 try:
     import bitsandbytes as bnb

@@ -18,7 +20,14 @@ except ImportError:
 from accelerate import init_empty_weights

 from text_generation_server.utils.gptq.quant_linear import QuantLinear
-from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
+from text_generation_server.utils.import_utils import (
+    IS_CUDA_SYSTEM,
+    IS_ROCM_SYSTEM,
+    IS_XPU_SYSTEM,
+)
+
+if IS_XPU_SYSTEM:
+    import intel_extension_for_pytorch as ipex

 HAS_AWQ = True
 try:

@@ -437,7 +446,7 @@ class MedusaModel(torch.nn.Module):
         self.heads = torch.nn.ModuleList(
             [
                 MedusaHead(config, medusa_config, prefix=f"{i}", weights=weights)
-                for i in range(medusa_config["medusa_num_heads"])
+                for i in range(get_speculate())
             ]
         )

@@ -534,7 +543,7 @@ class MedusaHeadV2(nn.Module):
                     )
                 routing[k] = filename

-        self.n_medusa_heads = medusa_config["medusa_num_heads"]
+        self.n_medusa_heads = get_speculate()

         assert medusa_config["medusa_num_layers"] == 1
         self.linear = TensorParallelColumnLinear.load_multi(

@@ -696,6 +705,19 @@ class TensorParallelHead(SuperLayer):

 class TensorParallelColumnLinear(SuperLayer):
+    @classmethod
+    def load_gate_up(cls, config, prefix: str, weights, bias: bool):
+        """Specific method when the QKV was joined after the fact"""
+        weight = weights.get_weights_col_packed_gate_up(
+            prefix, quantize=config.quantize
+        )
+        if bias:
+            raise NotImplementedError("packed_gate_up only implemented without bias")
+        else:
+            bias = None
+        linear = get_linear(weight, bias, config.quantize)
+        return cls(linear)
+
     @classmethod
     def load_qkv(cls, config, prefix: str, weights, bias: bool):
         """Specific method when the QKV was joined after the fact"""

@@ -799,7 +821,15 @@ try:

     class FastLayerNorm(nn.LayerNorm):
         def forward(self, hidden_states, residual=None):
-            if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
+            if IS_XPU_SYSTEM:
+                res_out = hidden_states
+                out = ipex.llm.functional.add_layer_norm(
+                    residual, hidden_states, self.weight, self.bias, self.eps, True
+                )
+                if residual is not None:
+                    res_out = residual
+                return out, res_out
+            elif hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
                 if residual is not None:
                     hidden_states += residual
                     residual = hidden_states

@@ -845,7 +875,20 @@ try:
             return cls(weight, eps)

         def forward(self, hidden_states, residual=None):
-            if hidden_states.shape[-1] > 8192:
+            if IS_XPU_SYSTEM:
+                residual_out = hidden_states
+                out = ipex.llm.functional.add_rms_norm(
+                    residual,
+                    hidden_states,
+                    self.weight,
+                    None,
+                    self.variance_epsilon,
+                    True,
+                )
+                if residual is not None:
+                    residual_out = residual
+                return out, residual_out
+            elif hidden_states.shape[-1] > 8192:
                 if residual is not None:
                     hidden_states += residual
                     residual = hidden_states

@@ -971,6 +1014,10 @@ try:
             # Inplace operation, updating query and key.
             pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, True)
+        elif IS_XPU_SYSTEM:
+            ipex.llm.functional.rotary_embedding(
+                query, key, sin, cos, query.size(-1), True
+            )
         else:
             raise ValueError(
                 "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."

@@ -1090,6 +1137,7 @@ try:
             cos = torch.index_select(self._cos_cached, 0, position_ids)
             sin = torch.index_select(self._sin_cached, 0, position_ids)
+            # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
             return cos.unsqueeze(1), sin.unsqueeze(1)
server/text_generation_server/utils/logits_process.py  (view file @ 5a1cf2f0)

@@ -143,13 +143,16 @@ class FrequencyPenaltyLogitsProcessor(LogitsProcessor):
         score = torch.gather(scores, 1, input_ids)
         # if score < 0 then penalty has to be multiplied to reduce the previous token probability
         score = -torch.where(score < 0, score * self.penalty, score / self.penalty)
+        # set score to 0 where input_ids is a padding token
+        score *= input_ids.ne(0)

         return scores.scatter_add_(1, input_ids, score)


 class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
     r"""
-    Frequency penalty as defined by OpenAI
+    Frequency penalty as defined by OpenAI in
+    https://platform.openai.com/docs/guides/text-generation/parameter-details

     Args:
         frequency_penalty (`List[float]`):

@@ -163,13 +166,19 @@ class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
         ).unsqueeze(1)

     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
-        score = torch.gather(scores, 1, input_ids)
-        # if score < 0 then penalty has to be multiplied to reduce the previous token probability
-        score = -torch.where(
-            score < 0, score * self.penalty_tensor, score / self.penalty_tensor
-        )
-
-        return scores.scatter_add_(1, input_ids, score)
+        batch_size, input_size = input_ids.size()
+        vocab_size = scores.size(1)
+
+        # Calculate the frequency for each token so far
+        token_freq = torch.zeros(batch_size, vocab_size, device=input_ids.device)
+        token_freq.scatter_add_(
+            1, input_ids, torch.ones_like(input_ids, dtype=torch.float)
+        )
+        token_freq /= input_size
+
+        # Apply the frequency penalty to logits
+        scores -= token_freq * self.penalty_tensor
+        return scores

     def filter(self, indices):
         self.penalty = [self.penalty[i] for i in indices]
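Note: a small numerical sketch of the new heterogeneous penalty: token frequencies are accumulated with scatter_add_ and subtracted from the logits. The 5-token vocabulary, the zero logits, and the penalty value are made up for the example.

import torch

input_ids = torch.tensor([[2, 2, 3, 1]])          # 4 generated tokens, token 2 appears twice
scores = torch.zeros(1, 5)                        # toy logits over a 5-token vocab
penalty_tensor = torch.tensor([[0.5]])            # one penalty per batch entry

batch_size, input_size = input_ids.size()
vocab_size = scores.size(1)

token_freq = torch.zeros(batch_size, vocab_size)
token_freq.scatter_add_(1, input_ids, torch.ones_like(input_ids, dtype=torch.float))
token_freq /= input_size                          # token 2 -> 0.5, tokens 1 and 3 -> 0.25

scores -= token_freq * penalty_tensor
print(scores)  # tensor([[ 0.0000, -0.1250, -0.2500, -0.1250,  0.0000]])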
server/text_generation_server/utils/paged_attention.py  (view file @ 5a1cf2f0)

 import torch
-from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
 from loguru import logger
+from text_generation_server.utils.import_utils import (
+    IS_CUDA_SYSTEM,
+    IS_ROCM_SYSTEM,
+    IS_XPU_SYSTEM,
+)

 _PARTITION_SIZE = 512

+if IS_XPU_SYSTEM:
+    import intel_extension_for_pytorch as ipex
+

 def reshape_and_cache(
     key: torch.Tensor,

@@ -22,8 +27,11 @@ def reshape_and_cache(
     elif IS_ROCM_SYSTEM:
         from vllm import cache_ops

-        cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots.int())
-        # cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
+        cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
+    elif IS_XPU_SYSTEM:
+        ipex.llm.modules.PagedAttention.reshape_and_cache(
+            key, value, key_cache, value_cache, slots
+        )
     else:
         raise ValueError("vllm is not supported on your system")

@@ -60,6 +68,22 @@ def attention(
     block_size = value_cache.shape[3]
     num_seqs, num_heads, head_size = query.shape
     max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
+    if IS_XPU_SYSTEM:
+        query = query.contiguous()
+        return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
+            out,
+            query,
+            key_cache,
+            value_cache,
+            kv_head_mapping,
+            softmax_scale,
+            block_tables,
+            input_lengths,
+            block_size,
+            max_s,
+            None,
+        )

     # NOTE(woosuk): We use a simple heuristic to decide whether to use
     # PagedAttention V1 or V2. If the number of partitions is 1, we use
     # V1 to avoid the overhead of reduction. Also, if the number of
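Note: a quick sanity check on the partitioning arithmetic used above (ceiling division of the max sequence length by _PARTITION_SIZE). The sequence lengths are arbitrary example values.

_PARTITION_SIZE = 512

for max_s in (1, 512, 513, 2048):
    max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
    print(max_s, "->", max_num_partitions)  # 1 -> 1, 512 -> 1, 513 -> 2, 2048 -> 4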
server/text_generation_server/utils/tokens.py  (view file @ 5a1cf2f0)

 import re
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Set, Union

 import math
 import torch

@@ -143,12 +143,22 @@ class StopSequenceCriteria:
 class StoppingCriteria:
     def __init__(
         self,
-        eos_token_id: int,
+        eos_token_ids: Optional[Union[Set[int], int]],
         stop_sequence_criterias: List[StopSequenceCriteria],
         max_new_tokens: int = 20,
        ignore_eos_token: bool = False,
     ):
-        self.eos_token_id = eos_token_id
+        if eos_token_ids is None:
+            eos_token_ids = set()
+        elif isinstance(eos_token_ids, int):
+            eos_token_ids = set([eos_token_ids])
+        elif isinstance(eos_token_ids, set):
+            eos_token_ids = eos_token_ids
+        else:
+            raise RuntimeError(
+                f"eos_token_ids is of invalid type {type(eos_token_ids)}, expected int, None or set[int]"
+            )
+        self.eos_token_ids = eos_token_ids
         self.stop_sequence_criterias = stop_sequence_criterias
         self.max_new_tokens = max_new_tokens
         self.current_tokens = 0

@@ -160,7 +170,10 @@ class StoppingCriteria:
         if self.current_tokens >= self.max_new_tokens:
             return True, FinishReason.FINISH_REASON_LENGTH

-        if not self.ignore_eos_token and last_token == self.eos_token_id:
+        if isinstance(last_token, torch.Tensor):
+            last_token = last_token.item()
+
+        if not self.ignore_eos_token and last_token in self.eos_token_ids:
             return True, FinishReason.FINISH_REASON_EOS_TOKEN

         if self.stop_sequence_criterias:
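Note: the constructor above now accepts an int, a set of ints, or None, so a model can stop on several token ids. A standalone sketch of the same normalization; the helper name exists only for this example.

from typing import Optional, Set, Union

def normalize_eos_token_ids(eos_token_ids: Optional[Union[Set[int], int]]) -> Set[int]:
    # Mirrors the constructor logic: None -> empty set, int -> singleton set, set -> unchanged.
    if eos_token_ids is None:
        return set()
    if isinstance(eos_token_ids, int):
        return {eos_token_ids}
    if isinstance(eos_token_ids, set):
        return eos_token_ids
    raise RuntimeError(
        f"eos_token_ids is of invalid type {type(eos_token_ids)}, expected int, None or set[int]"
    )

print(normalize_eos_token_ids(None))        # set()
print(normalize_eos_token_ids(2))           # {2}
print(normalize_eos_token_ids({2, 32000}))  # {2, 32000}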
@@ -184,8 +197,10 @@ class StoppingCriteria:
         stop_sequence_criterias = [
             StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
         ]
+        # TODO Hack because eos_token_id cannot be what we want.
+        eos_token_id = getattr(tokenizer, "_eos_token_ids", tokenizer.eos_token_id)
         return StoppingCriteria(
-            tokenizer.eos_token_id,
+            eos_token_id,
             stop_sequence_criterias,
             pb.max_new_tokens,
             pb.ignore_eos_token,

@@ -273,7 +288,7 @@ class HeterogeneousNextTokenChooser:
             else None
         )

-        if any([x != 1.0 for x in temperature]):
+        if any(x != 1.0 for x in temperature):
             do_sample = [
                 sample or x != 1.0 for x, sample in zip(temperature, do_sample)
             ]

@@ -281,15 +296,15 @@ class HeterogeneousNextTokenChooser:
                 HeterogeneousTemperatureLogitsWarper(temperature, dtype, device)
             )

-        if any([x != 0 for x in top_k]):
+        if any(x != 0 for x in top_k):
             do_sample = [sample or x != 0 for x, sample in zip(top_k, do_sample)]
             warpers.append(HeterogeneousTopKLogitsWarper(top_k, device))

-        if any([x < 1.0 for x in top_p]):
+        if any(x < 1.0 for x in top_p):
             do_sample = [sample or x < 1.0 for x, sample in zip(top_p, do_sample)]
             warpers.append(HeterogeneousTopPLogitsWarper(top_p, dtype, device))

-        if any([x < 1.0 for x in typical_p]):
+        if any(x < 1.0 for x in typical_p):
             do_sample = [sample or x < 1.0 for x, sample in zip(typical_p, do_sample)]
             warpers.append(HeterogeneousTypicalLogitsWarper(typical_p, dtype, device))
server/text_generation_server/utils/weights.py  (view file @ 5a1cf2f0)

@@ -141,6 +141,12 @@ class Weights:
         return weight

     def get_weights_col_packed_qkv(self, prefix: str, quantize: str):
+        return self.get_weights_col_packed(prefix, quantize, 3)
+
+    def get_weights_col_packed_gate_up(self, prefix: str, quantize: str):
+        return self.get_weights_col_packed(prefix, quantize, 2)
+
+    def get_weights_col_packed(self, prefix: str, quantize: str, blocks: int):
         """
         Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being
         already alternating Q,K,V within the main tensor

@@ -181,8 +187,8 @@ class Weights:
         else:
             slice_ = self._get_slice(f"{prefix}.weight")
             total_size = slice_.get_shape()[0]
-            assert total_size % 3 == 0, "Prepacked qkv is not divisible by 3"
-            single_size = total_size // 3
+            assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
+            single_size = total_size // blocks
             world_size = self.process_group.size()
             rank = self.process_group.rank()

@@ -192,10 +198,11 @@ class Weights:
             block_size = single_size // world_size
             start = rank * block_size
             stop = (rank + 1) * block_size
-            q = slice_[start:stop]
-            k = slice_[start + single_size : stop + single_size]
-            v = slice_[start + 2 * single_size : stop + 2 * single_size]
-            weight = torch.cat([q, k, v], dim=0)
+            tensors = []
+            for i in range(blocks):
+                tensor = slice_[start + i * single_size : stop + i * single_size]
+                tensors.append(tensor)
+            weight = torch.cat(tensors, dim=0)
             weight = weight.to(device=self.device)
             weight = weight.to(dtype=self.dtype)
             return weight
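Note: a toy illustration of the generalized sharding above. A tensor packed as [block0; block1] (e.g. gate and up projections, blocks=2) is split so each rank gets its slice of every block. The sizes, world size, and rank are invented for the example.

import torch

# Packed weight: 2 blocks of 8 rows each, e.g. gate_proj stacked on top of up_proj.
packed = torch.arange(16).reshape(16, 1)
blocks, world_size, rank = 2, 2, 0

total_size = packed.shape[0]
assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
single_size = total_size // blocks
block_size = single_size // world_size
start, stop = rank * block_size, (rank + 1) * block_size

tensors = []
for i in range(blocks):
    tensors.append(packed[start + i * single_size : stop + i * single_size])
weight = torch.cat(tensors, dim=0)
print(weight.flatten().tolist())  # rank 0 gets rows [0, 1, 2, 3] and [8, 9, 10, 11]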