OpenDAS / text-generation-inference

Commit 47954b81 (unverified), authored Sep 27, 2023 by OlivierDehaene, committed via GitHub on Sep 27, 2023.
Parent: b32e9ce9

    feat: format code (#1070)

Changes: 28. Showing 20 changed files with 636 additions and 231 deletions (+636, -231).
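The commit message does not name a tool, but the changes throughout the diff match black's conventions: an 88-column line limit, one argument per line with a trailing comma once a call no longer fits, long boolean conditions wrapped in parentheses, and two blank lines before top-level definitions. A minimal sketch of reproducing such a pass, assuming black with default settings is what was run (the commit itself does not state the formatter or version):

    # Hypothetical reproduction of the formatting pass; assumes `pip install black`.
    import subprocess

    # Format the Python sources touched by this commit.
    subprocess.run(
        ["black", "clients/python", "server", "integration-tests"],
        check=True,
    )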
Files changed:

clients/python/text_generation/client.py (+1, -1)
clients/python/text_generation/types.py (+3, -1)
integration-tests/models/test_flash_awq.py (+18, -9)
integration-tests/models/test_flash_awq_sharded.py (+20, -3)
integration-tests/models/test_idefics.py (+1, -3)
server/tests/utils/test_tokens.py (+5, -2)
server/text_generation_server/cli.py (+9, -4)
server/text_generation_server/models/__init__.py (+14, -14)
server/text_generation_server/models/causal_lm.py (+6, -3)
server/text_generation_server/models/custom_modeling/bloom_modeling.py (+4, -1)
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py (+5, -1)
server/text_generation_server/models/custom_modeling/idefics_image_processing.py (+22, -8)
server/text_generation_server/models/custom_modeling/idefics_modeling.py (+274, -86)
server/text_generation_server/models/custom_modeling/idefics_perceiver.py (+60, -29)
server/text_generation_server/models/custom_modeling/idefics_processing.py (+35, -9)
server/text_generation_server/models/custom_modeling/idefics_vision.py (+81, -24)
server/text_generation_server/models/custom_modeling/neox_modeling.py (+4, -1)
server/text_generation_server/models/flash_causal_lm.py (+6, -3)
server/text_generation_server/models/idefics_causal_lm.py (+66, -28)
server/text_generation_server/models/model.py (+2, -1)
clients/python/text_generation/client.py

@@ -137,7 +137,7 @@ class Client:
             typical_p=typical_p,
             watermark=watermark,
             decoder_input_details=decoder_input_details,
-            top_n_tokens=top_n_tokens
+            top_n_tokens=top_n_tokens,
         )
         request = Request(inputs=prompt, stream=False, parameters=parameters)
clients/python/text_generation/types.py

@@ -133,7 +133,9 @@ class Request(BaseModel):
             and parameters.best_of > 1
             and field_value
         ):
-            raise ValidationError("`best_of` != 1 is not supported when `stream` == True")
+            raise ValidationError(
+                "`best_of` != 1 is not supported when `stream` == True"
+            )
         return field_value
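For context, the guard reformatted above rejects streaming requests that ask for more than one candidate sequence. A hedged sketch of how it surfaces to a caller; the `Parameters` class and its `do_sample` field are assumptions beyond what this diff shows (only `inputs`, `stream`, `parameters`, and `parameters.best_of` appear in the commit):

    from text_generation.types import Parameters, Request

    # Fine: best_of > 1 with stream=False.
    Request(inputs="Hi", stream=False, parameters=Parameters(best_of=2, do_sample=True))

    # Trips the validator above:
    # ValidationError: `best_of` != 1 is not supported when `stream` == True
    Request(inputs="Hi", stream=True, parameters=Parameters(best_of=2, do_sample=True))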
integration-tests/models/test_flash_awq.py

@@ -3,7 +3,11 @@ import pytest
 @pytest.fixture(scope="module")
 def flash_llama_awq_handle(launcher):
-    with launcher("abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq", num_shard=1, quantize="awq") as handle:
+    with launcher(
+        "abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq",
+        num_shard=1,
+        quantize="awq",
+    ) as handle:
         yield handle

@@ -12,6 +16,7 @@ async def flash_llama_awq(flash_llama_awq_handle):
     await flash_llama_awq_handle.health(300)
     return flash_llama_awq_handle.client
 
+
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_awq(flash_llama_awq, response_snapshot):

@@ -20,11 +25,13 @@ async def test_flash_llama_awq(flash_llama_awq, response_snapshot):
     )
 
     assert response.details.generated_tokens == 10
-    assert response.generated_text == "\nWhat is the difference between Deep Learning and Machine"
+    assert (
+        response.generated_text
+        == "\nWhat is the difference between Deep Learning and Machine"
+    )
     assert response == response_snapshot
 
 
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):

@@ -49,16 +56,18 @@ async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_awq_load(
     flash_llama_awq, generate_load, response_snapshot
 ):
     responses = await generate_load(
         flash_llama_awq, "What is Deep Learning?", max_new_tokens=10, n=4
     )
 
     assert len(responses) == 4
-    assert all([r.generated_text == "\nWhat is the difference between Deep Learning and Machine" for r in responses])
+    assert all(
+        [
+            r.generated_text
+            == "\nWhat is the difference between Deep Learning and Machine"
+            for r in responses
+        ]
+    )
 
     assert responses == response_snapshot
integration-tests/models/test_flash_awq_sharded.py

 import pytest
 
 
 @pytest.fixture(scope="module")
 def flash_llama_awq_handle_sharded(launcher):
-    with launcher("abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq", num_shard=2, quantize="awq") as handle:
+    with launcher(
+        "abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq",
+        num_shard=2,
+        quantize="awq",
+    ) as handle:
         yield handle
 
 
 @pytest.fixture(scope="module")
 async def flash_llama_awq_sharded(flash_llama_awq_handle_sharded):
     await flash_llama_awq_handle_sharded.health(300)
     return flash_llama_awq_handle_sharded.client
 
 
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapshot):

@@ -18,9 +25,13 @@ async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapshot):
     )
 
     assert response.details.generated_tokens == 10
-    assert response.generated_text == "\nWhat is the difference between Deep Learning and Machine"
+    assert (
+        response.generated_text
+        == "\nWhat is the difference between Deep Learning and Machine"
+    )
     assert response == response_snapshot
 
 
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_awq_load_sharded(

@@ -31,6 +42,12 @@ async def test_flash_llama_awq_load_sharded(
     )
 
     assert len(responses) == 4
-    assert all([r.generated_text == "\nWhat is the difference between Deep Learning and Machine" for r in responses])
+    assert all(
+        [
+            r.generated_text
+            == "\nWhat is the difference between Deep Learning and Machine"
+            for r in responses
+        ]
+    )
 
     assert responses == response_snapshot
integration-tests/models/test_idefics.py

@@ -3,9 +3,7 @@ import pytest
 @pytest.fixture(scope="module")
 def idefics_handle(launcher):
-    with launcher(
-        "HuggingFaceM4/idefics-9b-instruct", num_shard=2
-    ) as handle:
+    with launcher("HuggingFaceM4/idefics-9b-instruct", num_shard=2) as handle:
         yield handle
server/tests/utils/test_tokens.py

@@ -45,12 +45,15 @@ def test_stopping_criteria_max():
     assert criteria(1, "") == (False, None)
     assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)
 
 
 def test_batch_top_tokens():
     top_n_tokens = [0, 2, 3, 4, 5]
     top_n_tokens_tensor = torch.tensor(top_n_tokens)
-    inp_logprobs = torch.tensor([[-1., -3., -4., -2., -3.]] * 5)
+    inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5)
 
-    topn_tok_ids, topn_tok_logprobs = batch_top_tokens(top_n_tokens, top_n_tokens_tensor, inp_logprobs)
+    topn_tok_ids, topn_tok_logprobs = batch_top_tokens(
+        top_n_tokens, top_n_tokens_tensor, inp_logprobs
+    )
 
     assert topn_tok_ids[0] == []
     assert topn_tok_ids[1] == [0, 3]
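The expected values in this test follow directly from the logprobs: every row is [-1.0, -3.0, -4.0, -2.0, -3.0], so a request with top_n_tokens = 2 keeps the two largest entries, index 0 (-1.0) and index 3 (-2.0), giving [0, 3], while top_n_tokens = 0 collects nothing. A standalone sketch of that selection (plain torch.topk, not the server's batch_top_tokens implementation):

    import torch

    inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5)
    for n, row in zip([0, 2, 3, 4, 5], inp_logprobs):
        # torch.topk returns the n largest logprobs and their token ids.
        _, indices = torch.topk(row, k=n)
        print(n, sorted(indices.tolist()))  # n=0 -> [], n=2 -> [0, 3], ...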
server/text_generation_server/cli.py

@@ -125,8 +125,12 @@ def download_weights(
     if not is_local_model:
         try:
-            adapter_config_filename = hf_hub_download(model_id, revision=revision, filename="adapter_config.json")
-            utils.download_and_unload_peft(model_id, revision, trust_remote_code=trust_remote_code)
+            adapter_config_filename = hf_hub_download(
+                model_id, revision=revision, filename="adapter_config.json"
+            )
+            utils.download_and_unload_peft(
+                model_id, revision, trust_remote_code=trust_remote_code
+            )
             is_local_model = True
             utils.weight_files(model_id, revision, extension)
             return

@@ -179,11 +183,12 @@ def download_weights(
         import transformers
         import json
 
         if is_local_model:
             config_filename = os.path.join(model_id, "config.json")
         else:
-            config_filename = hf_hub_download(model_id, revision=revision, filename="config.json")
+            config_filename = hf_hub_download(
+                model_id, revision=revision, filename="config.json"
+            )
 
         with open(config_filename, "r") as f:
             config = json.load(f)
 
         architecture = config["architectures"][0]
server/text_generation_server/models/__init__.py

@@ -153,7 +153,11 @@ def get_model(
         )
     elif model_type == "mpt":
         return MPTSharded(
-            model_id, revision, quantize=quantize, dtype=dtype, trust_remote_code=trust_remote_code
+            model_id,
+            revision,
+            quantize=quantize,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
         )
     elif model_type == "gpt_neox":

@@ -269,13 +273,9 @@ def get_model(
             "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
         )
     if quantize == "awq":
-        raise ValueError(
-            "awq quantization is not supported for AutoModel"
-        )
+        raise ValueError("awq quantization is not supported for AutoModel")
     elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
-        raise ValueError(
-            "4bit quantization is not supported for AutoModel"
-        )
+        raise ValueError("4bit quantization is not supported for AutoModel")
     if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
         return CausalLM(
             model_id,
server/text_generation_server/models/causal_lm.py

@@ -643,9 +643,12 @@ class CausalLM(Model):
                 # Decode generated tokens
                 output_text, _, _ = self.decode_token(
                     all_input_ids[:, 0],
-                    prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
+                    prefix_offset=len(all_input_ids)
+                    - stopping_criteria.current_tokens
+                    - 1,
                     read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
-                    skip_special_tokens=True
+                    skip_special_tokens=True,
                 )
 
                 # Get seed
                 if isinstance(next_token_chooser.choice, Sampling):
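The reformatted call above also exposes the decoding-window arithmetic: prefix_offset starts one token before the generated span and read_offset at its start, so the tokenizer can decode the new text incrementally with one token of context. A toy sketch of that windowing (assumed semantics, not the server's decode_token itself):

    # Toy illustration of incremental decoding offsets.
    all_input_ids = list(range(12))   # prompt + generated token ids
    current_tokens = 4                # tokens generated so far

    prefix_offset = len(all_input_ids) - current_tokens - 1
    read_offset = len(all_input_ids) - current_tokens

    prefix_ids = all_input_ids[prefix_offset:read_offset]  # 1 token of context
    new_ids = all_input_ids[read_offset:]                  # the generated span
    print(prefix_ids, new_ids)  # [7] [8, 9, 10, 11]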
server/text_generation_server/models/custom_modeling/bloom_modeling.py

@@ -40,7 +40,10 @@ from text_generation_server.utils.layers import (
 )
 
 CUSTOM_KERNELS_ENABLED = False
-if torch.cuda.is_available() and not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True":
+if (
+    torch.cuda.is_available()
+    and not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True"
+):
     try:
         from custom_kernels import fused_bloom_attention_cuda
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py

@@ -169,6 +169,7 @@ def load_attention(config, prefix, weights):
         bias=False,
     )
 
+
 def _load_gqa(config, prefix: str, weights):
     assert config.hidden_size % config.num_attention_heads == 0
     assert config.num_attention_heads % weights.process_group.size() == 0

@@ -211,7 +212,10 @@ class FlashLlamaAttention(torch.nn.Module):
         # config=config, prefix=f"{prefix}.rotary_emb", weights=weights
         # )
         self.rotary_emb = PositionRotaryEmbedding.static(
-            config=config, dim=self.head_size, base=config.rope_theta, device=weights.device
+            config=config,
+            dim=self.head_size,
+            base=config.rope_theta,
+            device=weights.device,
         )
 
         self.softmax_scale = self.head_size**-0.5
server/text_generation_server/models/custom_modeling/idefics_image_processing.py

@@ -20,7 +20,12 @@ import numpy as np
 from PIL import Image
 
 from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
-from transformers.image_transforms import resize, to_channel_dimension_format, rescale, normalize
+from transformers.image_transforms import (
+    resize,
+    to_channel_dimension_format,
+    rescale,
+    normalize,
+)
 from transformers.image_utils import (
     ChannelDimension,
     ImageInput,

@@ -121,7 +126,11 @@ class IdeficsImageProcessor(BaseImageProcessor):
             a PyTorch tensor of the processed images
         """
         image_size = image_size if image_size is not None else self.image_size
-        image_num_channels = image_num_channels if image_num_channels is not None else self.image_num_channels
+        image_num_channels = (
+            image_num_channels
+            if image_num_channels is not None
+            else self.image_num_channels
+        )
         image_mean = image_mean if image_mean is not None else self.image_mean
         image_std = image_std if image_std is not None else self.image_std
         size = (image_size, image_size)

@@ -160,9 +169,13 @@ class IdeficsImageProcessor(BaseImageProcessor):
         images = [resize(x, size, resample=PILImageResampling.BICUBIC) for x in images]
         images = [self.rescale(image=image, scale=1 / 255) for image in images]
         images = [self.normalize(x, mean=image_mean, std=image_std) for x in images]
-        images = [to_channel_dimension_format(x, ChannelDimension.FIRST) for x in images]
+        images = [
+            to_channel_dimension_format(x, ChannelDimension.FIRST) for x in images
+        ]
         # TODO: this converts to torch tensors - switch to convert_to_tensors once it becomes available
-        images = BatchFeature(data={"pixel_values": images}, tensor_type=TensorType.PYTORCH)["pixel_values"]
+        images = BatchFeature(
+            data={"pixel_values": images}, tensor_type=TensorType.PYTORCH
+        )["pixel_values"]
         return images

@@ -185,7 +198,9 @@ class IdeficsImageProcessor(BaseImageProcessor):
                 response.raise_for_status()
                 return Image.open(BytesIO(response.content))
         else:
-            raise ValueError(f"only a single or a list of entries is supported but got type={type(image_url_or_urls)}")
+            raise ValueError(
+                f"only a single or a list of entries is supported but got type={type(image_url_or_urls)}"
+            )
 
     def rescale(
         self,

@@ -255,10 +270,9 @@ class IdeficsImageProcessor(BaseImageProcessor):
             `np.ndarray`: The normalized image.
         """
         # TODO 4.32
-        return normalize(
-            image, mean=mean, std=std, data_format=data_format, **kwargs
-        )
+        return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
 
 
 import transformers
 
 transformers.IdeficsImageProcessor = IdeficsImageProcessor
server/text_generation_server/models/custom_modeling/idefics_modeling.py

@@ -28,7 +28,11 @@ from torch.nn import CrossEntropyLoss
 from transformers import PreTrainedModel
 from transformers.activations import ACT2FN
-from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, dataclass
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPast,
+    CausalLMOutputWithPast,
+    dataclass,
+)
 from transformers.modeling_utils import PretrainedConfig
 from transformers.utils import (
     add_start_docstrings,

@@ -37,8 +41,12 @@ from transformers.utils import (
     replace_return_docstrings,
 )
 from text_generation_server.models.custom_modeling.idefics_config import IdeficsConfig
-from text_generation_server.models.custom_modeling.idefics_vision import IdeficsVisionTransformer
-from text_generation_server.models.custom_modeling.idefics_perceiver import IdeficsPerceiverResampler
+from text_generation_server.models.custom_modeling.idefics_vision import (
+    IdeficsVisionTransformer,
+)
+from text_generation_server.models.custom_modeling.idefics_perceiver import (
+    IdeficsPerceiverResampler,
+)
 from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,

@@ -49,10 +57,12 @@ from text_generation_server.utils.layers import (
 )
 
 import dropout_layer_norm
 
+
 @dataclass
 class BaseModelOutputWithPastImage(BaseModelOutputWithPast):
     image_hidden_states: Optional[torch.FloatTensor] = None
 
+
 @dataclass
 class CausalLMOutputWithPastImage(CausalLMOutputWithPast):
     image_hidden_states: Optional[torch.FloatTensor] = None

@@ -78,25 +88,39 @@ def expand_inputs_for_generation(
     **model_kwargs,
 ):
     expanded_return_idx = (
-        torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device)
+        torch.arange(input_ids.shape[0])
+        .view(-1, 1)
+        .repeat(1, expand_size)
+        .view(-1)
+        .to(input_ids.device)
     )
     input_ids = input_ids.index_select(0, expanded_return_idx)
 
     if "token_type_ids" in model_kwargs:
         token_type_ids = model_kwargs["token_type_ids"]
-        model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx)
+        model_kwargs["token_type_ids"] = token_type_ids.index_select(
+            0, expanded_return_idx
+        )
 
     if attention_mask is not None:
-        model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)
-        model_kwargs["image_attention_mask"] = model_kwargs["image_attention_mask"].index_select(0, expanded_return_idx)
-        model_kwargs["pixel_values"] = model_kwargs["pixel_values"].index_select(0, expanded_return_idx)
+        model_kwargs["attention_mask"] = attention_mask.index_select(
+            0, expanded_return_idx
+        )
+        model_kwargs["image_attention_mask"] = model_kwargs[
+            "image_attention_mask"
+        ].index_select(0, expanded_return_idx)
+        model_kwargs["pixel_values"] = model_kwargs["pixel_values"].index_select(
+            0, expanded_return_idx
+        )
 
     if is_encoder_decoder:
         if encoder_outputs is None:
-            raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
+            raise ValueError(
+                "If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined."
+            )
         encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
             0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
         )
         model_kwargs["encoder_outputs"] = encoder_outputs

@@ -120,14 +144,17 @@ def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder
     # update token_type_ids with last value
     if "token_type_ids" in model_kwargs:
         token_type_ids = model_kwargs["token_type_ids"]
-        model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
+        model_kwargs["token_type_ids"] = torch.cat(
+            [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1
+        )
 
     # update attention masks
     if not is_encoder_decoder:
         if "attention_mask" in model_kwargs:
             attention_mask = model_kwargs["attention_mask"]
             model_kwargs["attention_mask"] = torch.cat(
-                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))],
+                dim=-1,
             )
         if "image_attention_mask" in model_kwargs:
             image_attention_mask = model_kwargs["image_attention_mask"]

@@ -180,8 +207,12 @@ def freeze_model(model, module_exceptions=[]):
     }
     module_exceptions_mapped = [mapping[m] for m in module_exceptions]
     for module in model.modules():
-        if module_exceptions and any([isinstance(module, t) for t in module_exceptions_mapped]):
-            module.requires_grad_(True)  # Explicitely setting it to true to avoid any mistakes
+        if module_exceptions and any(
+            [isinstance(module, t) for t in module_exceptions_mapped]
+        ):
+            module.requires_grad_(
+                True
+            )  # Explicitely setting it to true to avoid any mistakes
         else:
             module.requires_grad_(False)
     return model

@@ -195,15 +226,21 @@ class IdeficsDecoupledPartialTPEmbedding(nn.Module):
     ):
         super().__init__()
         self.num_embeddings = config.vocab_size
-        self.weight = TensorParallelEmbedding(prefix="model.embed_tokens", weights=weights)
-        self.additional_weight = nn.Parameter(weights.get_tensor(f"model.embed_tokens.additional_embedding.weight"))
+        self.weight = TensorParallelEmbedding(
+            prefix="model.embed_tokens", weights=weights
+        )
+        self.additional_weight = nn.Parameter(
+            weights.get_tensor(f"model.embed_tokens.additional_embedding.weight")
+        )
 
     def forward(self, input_ids):
         # Clone so that we don't modify the original input_ids later on
         input_ids = input_ids.clone()
         additional_vocab_indices = torch.where(input_ids >= self.num_embeddings)
         input_ids_additional_vocab = input_ids[additional_vocab_indices]
-        additional_embeddings = torch.nn.functional.embedding(input_ids_additional_vocab - self.num_embeddings, self.additional_weight)
+        additional_embeddings = torch.nn.functional.embedding(
+            input_ids_additional_vocab - self.num_embeddings, self.additional_weight
+        )
 
         # for successful lookup replace input_ids with 0, the results of these will be discarded anyway
         input_ids[additional_vocab_indices] = 0

@@ -234,7 +271,10 @@ class IdeficsDecoupledTensorParallelLinear(nn.Module):
             config=config, prefix="lm_head", weights=weights
         )
         self.additional_fc = FastLinear.load(
-            config=config, prefix="lm_head.additional_fc", weights=weights, bias=False,
+            config=config,
+            prefix="lm_head.additional_fc",
+            weights=weights,
+            bias=False,
         )
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:

@@ -257,7 +297,10 @@ class IdeficsDecoupledTensorParallelLinear(nn.Module):
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
 def _make_causal_mask(
-    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+    input_ids_shape: torch.Size,
+    dtype: torch.dtype,
+    device: torch.device,
+    past_key_values_length: int = 0,
 ):
     """
     Make causal mask used for bi-directional self-attention.

@@ -269,8 +312,18 @@ def _make_causal_mask(
     mask = mask.to(dtype)
 
     if past_key_values_length > 0:
-        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
-    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
+        mask = torch.cat(
+            [
+                torch.zeros(
+                    tgt_len, past_key_values_length, dtype=dtype, device=device
+                ),
+                mask,
+            ],
+            dim=-1,
+        )
+    return mask[None, None, :, :].expand(
+        bsz, 1, tgt_len, tgt_len + past_key_values_length
+    )
 
 
 def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
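The _make_causal_mask helper reformatted above builds the standard lower-triangular attention mask, left-padded with zero columns so that cached past positions stay visible. A small self-contained illustration of that behavior (plain PyTorch, written for this note rather than taken from the Idefics code):

    import torch

    def make_causal_mask(tgt_len, past_len, dtype=torch.float32):
        # Disallowed positions get the most negative representable value,
        # allowed positions get 0, mirroring the masking convention above.
        mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min)
        mask.masked_fill_(torch.tril(torch.ones(tgt_len, tgt_len, dtype=torch.bool)), 0.0)
        if past_len > 0:
            # Cached past tokens are always attendable: prepend zero columns.
            mask = torch.cat([torch.zeros(tgt_len, past_len), mask], dim=-1)
        return mask

    print(make_causal_mask(3, 2))  # shape (3, 5); row i sees past + tokens <= i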
...
@@ -346,7 +401,6 @@ class IdeficsRMSNorm(nn.Module):
...
@@ -346,7 +401,6 @@ class IdeficsRMSNorm(nn.Module):
if
unwrap
:
if
unwrap
:
normed_hidden_states
=
normed_hidden_states
.
view
(
*
shape
)
normed_hidden_states
=
normed_hidden_states
.
view
(
*
shape
)
return
normed_hidden_states
return
normed_hidden_states
...
@@ -367,7 +421,10 @@ class IdeficsMLP(nn.Module):
...
@@ -367,7 +421,10 @@ class IdeficsMLP(nn.Module):
bias
=
False
,
bias
=
False
,
)
)
self
.
down_proj
=
TensorParallelRowLinear
.
load
(
self
.
down_proj
=
TensorParallelRowLinear
.
load
(
config
,
prefix
=
f
"
{
prefix
}
.down_proj"
,
weights
=
weights
,
bias
=
False
,
config
,
prefix
=
f
"
{
prefix
}
.down_proj"
,
weights
=
weights
,
bias
=
False
,
)
)
self
.
act_fn
=
ACT2FN
[
config
.
hidden_act
]
self
.
act_fn
=
ACT2FN
[
config
.
hidden_act
]
...
@@ -375,7 +432,9 @@ class IdeficsMLP(nn.Module):
...
@@ -375,7 +432,9 @@ class IdeficsMLP(nn.Module):
gate_up_states
=
self
.
gate_up_proj
(
hidden_states
)
gate_up_states
=
self
.
gate_up_proj
(
hidden_states
)
shape
=
gate_up_states
.
shape
shape
=
gate_up_states
.
shape
gate_up_states
=
gate_up_states
.
view
(
*
shape
[:
-
1
],
2
,
shape
[
-
1
]
//
2
)
gate_up_states
=
gate_up_states
.
view
(
*
shape
[:
-
1
],
2
,
shape
[
-
1
]
//
2
)
return
self
.
down_proj
(
self
.
act_fn
(
gate_up_states
[:,
:,
0
])
*
gate_up_states
[:,
:,
1
])
return
self
.
down_proj
(
self
.
act_fn
(
gate_up_states
[:,
:,
0
])
*
gate_up_states
[:,
:,
1
]
)
# this was adapted from LlamaAttention
# this was adapted from LlamaAttention
...
@@ -445,14 +504,22 @@ class IdeficsAttention(nn.Module):
...
@@ -445,14 +504,22 @@ class IdeficsAttention(nn.Module):
self
.
qk_layer_norms
=
qk_layer_norms
self
.
qk_layer_norms
=
qk_layer_norms
if
self
.
qk_layer_norms
:
if
self
.
qk_layer_norms
:
self
.
q_layer_norm
=
IdeficsRMSNorm
(
self
.
q_layer_norm
=
IdeficsRMSNorm
(
prefix
=
f
"
{
prefix
}
.q_layer_norm"
,
weights
=
weights
,
eps
=
config
.
rms_norm_eps
prefix
=
f
"
{
prefix
}
.q_layer_norm"
,
weights
=
weights
,
eps
=
config
.
rms_norm_eps
,
)
)
self
.
k_layer_norm
=
IdeficsRMSNorm
(
self
.
k_layer_norm
=
IdeficsRMSNorm
(
prefix
=
f
"
{
prefix
}
.q_layer_norm"
,
weights
=
weights
,
eps
=
config
.
rms_norm_eps
prefix
=
f
"
{
prefix
}
.q_layer_norm"
,
weights
=
weights
,
eps
=
config
.
rms_norm_eps
,
)
)
def
_shape
(
self
,
tensor
:
torch
.
Tensor
,
seq_len
:
int
,
bsz
:
int
):
def
_shape
(
self
,
tensor
:
torch
.
Tensor
,
seq_len
:
int
,
bsz
:
int
):
return
tensor
.
view
(
bsz
,
seq_len
,
self
.
num_heads
,
self
.
head_dim
).
transpose
(
1
,
2
).
contiguous
()
return
(
tensor
.
view
(
bsz
,
seq_len
,
self
.
num_heads
,
self
.
head_dim
)
.
transpose
(
1
,
2
)
.
contiguous
()
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -470,20 +537,42 @@ class IdeficsAttention(nn.Module):
...
@@ -470,20 +537,42 @@ class IdeficsAttention(nn.Module):
bsz
,
q_len
,
_
=
hidden_states
.
size
()
bsz
,
q_len
,
_
=
hidden_states
.
size
()
if
is_cross_attention
:
if
is_cross_attention
:
query_states
=
self
.
q_proj
(
hidden_states
).
view
(
bsz
,
q_len
,
self
.
num_heads
,
self
.
head_dim
)
# .transpose(1, 2)
query_states
=
self
.
q_proj
(
hidden_states
).
view
(
bsz
,
q_len
,
self
.
num_heads
,
self
.
head_dim
)
# .transpose(1, 2)
query_states
=
query_states
.
transpose
(
1
,
2
)
query_states
=
query_states
.
transpose
(
1
,
2
)
_
,
kv_len
,
_
=
key_value_states
.
size
()
# Note that, in this case, `kv_len` == `kv_seq_len`
(
key_states
=
self
.
k_proj
(
key_value_states
).
view
(
bsz
,
kv_len
,
self
.
num_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
_
,
kv_len
,
_
,
)
=
(
key_value_states
.
size
()
)
# Note that, in this case, `kv_len` == `kv_seq_len`
key_states
=
(
self
.
k_proj
(
key_value_states
)
.
view
(
bsz
,
kv_len
,
self
.
num_heads
,
self
.
head_dim
)
.
transpose
(
1
,
2
)
)
value_states
=
(
value_states
=
(
self
.
v_proj
(
key_value_states
).
view
(
bsz
,
kv_len
,
self
.
num_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
self
.
v_proj
(
key_value_states
)
.
view
(
bsz
,
kv_len
,
self
.
num_heads
,
self
.
head_dim
)
.
transpose
(
1
,
2
)
)
)
else
:
else
:
qkv
=
self
.
qkv
(
hidden_states
)
qkv
=
self
.
qkv
(
hidden_states
)
query_states
,
key_states
,
value_states
=
qkv
.
split
(
self
.
num_heads
*
self
.
head_dim
,
dim
=
2
)
query_states
,
key_states
,
value_states
=
qkv
.
split
(
self
.
num_heads
*
self
.
head_dim
,
dim
=
2
)
query_states
=
query_states
.
view
(
bsz
,
q_len
,
self
.
num_heads
,
self
.
head_dim
)
# .transpose(1, 2)
query_states
=
query_states
.
view
(
key_states
=
key_states
.
view
(
bsz
,
q_len
,
self
.
num_heads
,
self
.
head_dim
)
# . transpose(1, 2)
bsz
,
q_len
,
self
.
num_heads
,
self
.
head_dim
value_states
=
value_states
.
view
(
bsz
,
q_len
,
self
.
num_heads
,
self
.
head_dim
)
# .transpose(1, 2)
)
# .transpose(1, 2)
key_states
=
key_states
.
view
(
bsz
,
q_len
,
self
.
num_heads
,
self
.
head_dim
)
# . transpose(1, 2)
value_states
=
value_states
.
view
(
bsz
,
q_len
,
self
.
num_heads
,
self
.
head_dim
)
# .transpose(1, 2)
kv_seq_len
=
q_len
kv_seq_len
=
q_len
if
past_key_value
is
not
None
:
if
past_key_value
is
not
None
:
kv_seq_len
+=
past_key_value
[
0
].
shape
[
-
2
]
kv_seq_len
+=
past_key_value
[
0
].
shape
[
-
2
]
...
@@ -493,10 +582,14 @@ class IdeficsAttention(nn.Module):
...
@@ -493,10 +582,14 @@ class IdeficsAttention(nn.Module):
)
)
shape
=
query_states
.
shape
shape
=
query_states
.
shape
query_states
=
self
.
rotary_emb
(
query_states
.
view
(
-
1
,
*
shape
[
2
:]),
cos
,
sin
).
view
(
shape
)
query_states
=
self
.
rotary_emb
(
query_states
.
view
(
-
1
,
*
shape
[
2
:]),
cos
,
sin
).
view
(
shape
)
shape
=
key_states
.
shape
shape
=
key_states
.
shape
key_states
=
self
.
rotary_emb
(
key_states
.
reshape
(
-
1
,
*
shape
[
2
:]),
cos
,
sin
).
view
(
shape
)
key_states
=
self
.
rotary_emb
(
key_states
.
reshape
(
-
1
,
*
shape
[
2
:]),
cos
,
sin
).
view
(
shape
)
query_states
=
query_states
.
transpose
(
1
,
2
)
query_states
=
query_states
.
transpose
(
1
,
2
)
key_states
=
key_states
.
transpose
(
1
,
2
)
key_states
=
key_states
.
transpose
(
1
,
2
)
...
@@ -571,8 +664,14 @@ class IdeficsDecoderLayer(nn.Module):
...
@@ -571,8 +664,14 @@ class IdeficsDecoderLayer(nn.Module):
prefix
=
f
"
{
prefix
}
.mlp"
,
prefix
=
f
"
{
prefix
}
.mlp"
,
weights
=
weights
,
weights
=
weights
,
)
)
self
.
input_layernorm
=
IdeficsRMSNorm
(
prefix
=
f
"
{
prefix
}
.input_layernorm"
,
weights
=
weights
,
eps
=
config
.
rms_norm_eps
)
self
.
input_layernorm
=
IdeficsRMSNorm
(
self
.
post_attention_layernorm
=
IdeficsRMSNorm
(
prefix
=
f
"
{
prefix
}
.post_attention_layernorm"
,
weights
=
weights
,
eps
=
config
.
rms_norm_eps
)
prefix
=
f
"
{
prefix
}
.input_layernorm"
,
weights
=
weights
,
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm
=
IdeficsRMSNorm
(
prefix
=
f
"
{
prefix
}
.post_attention_layernorm"
,
weights
=
weights
,
eps
=
config
.
rms_norm_eps
,
)
self
.
dropout
=
config
.
dropout
self
.
dropout
=
config
.
dropout
def
forward
(
def
forward
(
...
@@ -583,7 +682,9 @@ class IdeficsDecoderLayer(nn.Module):
...
@@ -583,7 +682,9 @@ class IdeficsDecoderLayer(nn.Module):
past_key_value
:
Optional
[
Tuple
[
torch
.
Tensor
]]
=
None
,
past_key_value
:
Optional
[
Tuple
[
torch
.
Tensor
]]
=
None
,
output_attentions
:
Optional
[
bool
]
=
False
,
output_attentions
:
Optional
[
bool
]
=
False
,
use_cache
:
Optional
[
bool
]
=
False
,
use_cache
:
Optional
[
bool
]
=
False
,
)
->
Tuple
[
torch
.
FloatTensor
,
Optional
[
Tuple
[
torch
.
FloatTensor
,
torch
.
FloatTensor
]]]:
)
->
Tuple
[
torch
.
FloatTensor
,
Optional
[
Tuple
[
torch
.
FloatTensor
,
torch
.
FloatTensor
]]
]:
"""
"""
Args:
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
...
@@ -650,14 +751,22 @@ class IdeficsGatedCrossAttentionLayer(nn.Module):
...
@@ -650,14 +751,22 @@ class IdeficsGatedCrossAttentionLayer(nn.Module):
prefix
=
f
"
{
prefix
}
.mlp"
,
prefix
=
f
"
{
prefix
}
.mlp"
,
weights
=
weights
,
weights
=
weights
,
)
)
self
.
input_layernorm
=
IdeficsRMSNorm
(
prefix
=
f
"
{
prefix
}
.input_layernorm"
,
weights
=
weights
,
eps
=
config
.
rms_norm_eps
)
self
.
input_layernorm
=
IdeficsRMSNorm
(
self
.
post_attention_layernorm
=
IdeficsRMSNorm
(
prefix
=
f
"
{
prefix
}
.post_attention_layernorm"
,
weights
=
weights
,
eps
=
config
.
rms_norm_eps
)
prefix
=
f
"
{
prefix
}
.input_layernorm"
,
weights
=
weights
,
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm
=
IdeficsRMSNorm
(
prefix
=
f
"
{
prefix
}
.post_attention_layernorm"
,
weights
=
weights
,
eps
=
config
.
rms_norm_eps
,
)
self
.
config
=
config
.
dropout
self
.
config
=
config
.
dropout
self
.
act_cross_attn
=
nn
.
Tanh
()
self
.
act_cross_attn
=
nn
.
Tanh
()
self
.
act_dense
=
nn
.
Tanh
()
self
.
act_dense
=
nn
.
Tanh
()
self
.
alpha_cross_attn
=
nn
.
Parameter
(
weights
.
get_tensor
(
f
"
{
prefix
}
.alpha_cross_attn"
))
self
.
alpha_cross_attn
=
nn
.
Parameter
(
weights
.
get_tensor
(
f
"
{
prefix
}
.alpha_cross_attn"
)
)
self
.
alpha_dense
=
nn
.
Parameter
(
weights
.
get_tensor
(
f
"
{
prefix
}
.alpha_dense"
))
self
.
alpha_dense
=
nn
.
Parameter
(
weights
.
get_tensor
(
f
"
{
prefix
}
.alpha_dense"
))
if
not
(
hasattr
(
self
,
"alpha_cross_attn"
)
and
hasattr
(
self
,
"alpha_dense"
)):
if
not
(
hasattr
(
self
,
"alpha_cross_attn"
)
and
hasattr
(
self
,
"alpha_dense"
)):
...
@@ -673,7 +782,9 @@ class IdeficsGatedCrossAttentionLayer(nn.Module):
...
@@ -673,7 +782,9 @@ class IdeficsGatedCrossAttentionLayer(nn.Module):
use_cache
:
Optional
[
bool
]
=
False
,
use_cache
:
Optional
[
bool
]
=
False
,
past_key_value
:
Optional
[
Tuple
[
torch
.
Tensor
]]
=
None
,
past_key_value
:
Optional
[
Tuple
[
torch
.
Tensor
]]
=
None
,
no_images
:
Optional
[
bool
]
=
False
,
no_images
:
Optional
[
bool
]
=
False
,
)
->
Tuple
[
torch
.
FloatTensor
,
Optional
[
Tuple
[
torch
.
FloatTensor
,
torch
.
FloatTensor
]]]:
)
->
Tuple
[
torch
.
FloatTensor
,
Optional
[
Tuple
[
torch
.
FloatTensor
,
torch
.
FloatTensor
]]
]:
"""
"""
Args:
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
...
@@ -695,7 +806,9 @@ class IdeficsGatedCrossAttentionLayer(nn.Module):
...
@@ -695,7 +806,9 @@ class IdeficsGatedCrossAttentionLayer(nn.Module):
)
)
if
past_key_value
is
not
None
:
if
past_key_value
is
not
None
:
raise
NotImplementedError
(
"Past key value states are not implemented for Idefics cross attention module."
)
raise
NotImplementedError
(
"Past key value states are not implemented for Idefics cross attention module."
)
residual
=
hidden_states
residual
=
hidden_states
...
@@ -711,7 +824,9 @@ class IdeficsGatedCrossAttentionLayer(nn.Module):
...
@@ -711,7 +824,9 @@ class IdeficsGatedCrossAttentionLayer(nn.Module):
# hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training)
# hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training)
# when there are no images the model is used in pure language mode
# when there are no images the model is used in pure language mode
gate
=
0
if
no_images
else
1
gate
=
0
if
no_images
else
1
hidden_states
=
residual
+
gate
*
self
.
act_cross_attn
(
self
.
alpha_cross_attn
)
*
hidden_states
hidden_states
=
(
residual
+
gate
*
self
.
act_cross_attn
(
self
.
alpha_cross_attn
)
*
hidden_states
)
# Fully Connected
# Fully Connected
residual
=
hidden_states
residual
=
hidden_states
...
@@ -896,11 +1011,14 @@ class IdeficsModel(IdeficsPreTrainedModel):
...
@@ -896,11 +1011,14 @@ class IdeficsModel(IdeficsPreTrainedModel):
self
.
gated_cross_attn_layers
=
nn
.
ModuleList
(
self
.
gated_cross_attn_layers
=
nn
.
ModuleList
(
[
[
IdeficsGatedCrossAttentionLayer
(
layer_id
,
config
,
weights
)
IdeficsGatedCrossAttentionLayer
(
layer_id
,
config
,
weights
)
for
layer_id
in
range
(
num_cross_layers
)]
for
layer_id
in
range
(
num_cross_layers
)
]
)
)
# self.gradient_checkpointing = False
# self.gradient_checkpointing = False
self
.
norm
=
IdeficsRMSNorm
(
prefix
=
f
"model.norm"
,
weights
=
weights
,
eps
=
config
.
rms_norm_eps
)
self
.
norm
=
IdeficsRMSNorm
(
prefix
=
f
"model.norm"
,
weights
=
weights
,
eps
=
config
.
rms_norm_eps
)
# self.gradient_checkpointing = False
# self.gradient_checkpointing = False
# Initialize weights and apply final processing
# Initialize weights and apply final processing
...
@@ -932,7 +1050,9 @@ class IdeficsModel(IdeficsPreTrainedModel):
...
@@ -932,7 +1050,9 @@ class IdeficsModel(IdeficsPreTrainedModel):
# self.embed_tokens = value
# self.embed_tokens = value
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def
_prepare_decoder_attention_mask
(
self
,
attention_mask
,
input_shape
,
inputs_embeds
,
past_key_values_length
):
def
_prepare_decoder_attention_mask
(
self
,
attention_mask
,
input_shape
,
inputs_embeds
,
past_key_values_length
):
# create causal mask
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask
=
None
combined_attention_mask
=
None
...
@@ -946,11 +1066,13 @@ class IdeficsModel(IdeficsPreTrainedModel):
...
@@ -946,11 +1066,13 @@ class IdeficsModel(IdeficsPreTrainedModel):
if
attention_mask
is
not
None
:
if
attention_mask
is
not
None
:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask
=
_expand_mask
(
attention_mask
,
inputs_embeds
.
dtype
,
tgt_len
=
input_shape
[
-
1
]).
to
(
expanded_attn_mask
=
_expand_mask
(
inputs_embeds
.
device
attention_mask
,
inputs_embeds
.
dtype
,
tgt_len
=
input_shape
[
-
1
]
)
)
.
to
(
inputs_embeds
.
device
)
combined_attention_mask
=
(
combined_attention_mask
=
(
expanded_attn_mask
if
combined_attention_mask
is
None
else
expanded_attn_mask
+
combined_attention_mask
expanded_attn_mask
if
combined_attention_mask
is
None
else
expanded_attn_mask
+
combined_attention_mask
)
)
return
combined_attention_mask
return
combined_attention_mask
...
@@ -974,23 +1096,35 @@ class IdeficsModel(IdeficsPreTrainedModel):
...
@@ -974,23 +1096,35 @@ class IdeficsModel(IdeficsPreTrainedModel):
)
->
Union
[
Tuple
,
BaseModelOutputWithPastImage
]:
)
->
Union
[
Tuple
,
BaseModelOutputWithPastImage
]:
device
=
input_ids
.
device
if
input_ids
is
not
None
else
inputs_embeds
.
device
device
=
input_ids
.
device
if
input_ids
is
not
None
else
inputs_embeds
.
device
output_attentions
=
output_attentions
if
output_attentions
is
not
None
else
self
.
config
.
output_attentions
output_attentions
=
(
output_attentions
if
output_attentions
is
not
None
else
self
.
config
.
output_attentions
)
output_hidden_states
=
(
output_hidden_states
=
(
output_hidden_states
if
output_hidden_states
is
not
None
else
self
.
config
.
output_hidden_states
output_hidden_states
if
output_hidden_states
is
not
None
else
self
.
config
.
output_hidden_states
)
)
use_cache
=
use_cache
if
use_cache
is
not
None
else
self
.
config
.
use_cache
use_cache
=
use_cache
if
use_cache
is
not
None
else
self
.
config
.
use_cache
return_dict
=
return_dict
if
return_dict
is
not
None
else
self
.
config
.
use_return_dict
return_dict
=
(
return_dict
if
return_dict
is
not
None
else
self
.
config
.
use_return_dict
)
# retrieve input_ids and inputs_embeds
# retrieve input_ids and inputs_embeds
if
input_ids
is
not
None
and
inputs_embeds
is
not
None
:
if
input_ids
is
not
None
and
inputs_embeds
is
not
None
:
raise
ValueError
(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
raise
ValueError
(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
elif
input_ids
is
not
None
:
elif
input_ids
is
not
None
:
batch_size
,
seq_length
=
input_ids
.
shape
batch_size
,
seq_length
=
input_ids
.
shape
elif
inputs_embeds
is
not
None
:
elif
inputs_embeds
is
not
None
:
batch_size
,
seq_length
,
_
=
inputs_embeds
.
shape
batch_size
,
seq_length
,
_
=
inputs_embeds
.
shape
else
:
else
:
raise
ValueError
(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
raise
ValueError
(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
seq_length_with_past
=
seq_length
seq_length_with_past
=
seq_length
past_key_values_length
=
0
past_key_values_length
=
0
...
@@ -1006,7 +1140,10 @@ class IdeficsModel(IdeficsPreTrainedModel):
...
@@ -1006,7 +1140,10 @@ class IdeficsModel(IdeficsPreTrainedModel):
elif
position_ids
is
None
:
elif
position_ids
is
None
:
device
=
input_ids
.
device
if
input_ids
is
not
None
else
inputs_embeds
.
device
device
=
input_ids
.
device
if
input_ids
is
not
None
else
inputs_embeds
.
device
position_ids
=
torch
.
arange
(
position_ids
=
torch
.
arange
(
past_key_values_length
,
seq_length
+
past_key_values_length
,
dtype
=
torch
.
long
,
device
=
device
past_key_values_length
,
seq_length
+
past_key_values_length
,
dtype
=
torch
.
long
,
device
=
device
,
)
)
position_ids
=
position_ids
.
unsqueeze
(
0
).
view
(
-
1
,
seq_length
)
position_ids
=
position_ids
.
unsqueeze
(
0
).
view
(
-
1
,
seq_length
)
else
:
else
:
...
@@ -1016,29 +1153,52 @@ class IdeficsModel(IdeficsPreTrainedModel):
...
@@ -1016,29 +1153,52 @@ class IdeficsModel(IdeficsPreTrainedModel):
if
image_hidden_states
is
None
:
if
image_hidden_states
is
None
:
if
pixel_values
is
None
and
image_embeddings
is
None
:
if
pixel_values
is
None
and
image_embeddings
is
None
:
raise
ValueError
(
"Either pixel_values and image_embeddings have to be not-None."
)
raise
ValueError
(
"Either pixel_values and image_embeddings have to be not-None."
)
elif
pixel_values
is
not
None
and
image_embeddings
is
not
None
:
elif
pixel_values
is
not
None
and
image_embeddings
is
not
None
:
raise
ValueError
(
"You cannot specify both pixel_values and image_embeddings at the same time"
)
raise
ValueError
(
"You cannot specify both pixel_values and image_embeddings at the same time"
)
elif
pixel_values
is
not
None
:
elif
pixel_values
is
not
None
:
no_images
=
len
(
torch
.
nonzero
(
pixel_values
))
==
0
no_images
=
len
(
torch
.
nonzero
(
pixel_values
))
==
0
pixel_values
=
pixel_values
.
to
(
dtype
=
self
.
dtype
,
device
=
device
)
# fp16 compatibility
pixel_values
=
pixel_values
.
to
(
dtype
=
self
.
dtype
,
device
=
device
)
# fp16 compatibility
batch_size
,
num_images
=
pixel_values
.
shape
[:
2
]
batch_size
,
num_images
=
pixel_values
.
shape
[:
2
]
pixel_values
=
pixel_values
.
contiguous
().
view
(
batch_size
*
num_images
,
*
pixel_values
.
shape
[
2
:])
pixel_values
=
pixel_values
.
contiguous
().
view
(
batch_size
*
num_images
,
*
pixel_values
.
shape
[
2
:]
)
# Get sequence from the vision encoder
# Get sequence from the vision encoder
image_hidden_states
=
self
.
vision_model
(
pixel_values
=
pixel_values
).
last_hidden_state
image_hidden_states
=
self
.
vision_model
(
pixel_values
=
pixel_values
).
last_hidden_state
elif
image_embeddings
is
not
None
:
elif
image_embeddings
is
not
None
:
batch_size
,
num_images
,
image_seq_len
,
image_hidden_size
=
image_embeddings
.
size
()
(
image_hidden_states
=
image_embeddings
.
to
(
dtype
=
self
.
dtype
,
device
=
input_ids
.
device
)
batch_size
,
image_hidden_states
=
image_hidden_states
.
view
(
batch_size
*
num_images
,
image_seq_len
,
image_hidden_size
)
num_images
,
image_seq_len
,
image_hidden_size
,
)
=
image_embeddings
.
size
()
image_hidden_states
=
image_embeddings
.
to
(
dtype
=
self
.
dtype
,
device
=
input_ids
.
device
)
image_hidden_states
=
image_hidden_states
.
view
(
batch_size
*
num_images
,
image_seq_len
,
image_hidden_size
)
if
self
.
config
.
use_resampler
:
if
self
.
config
.
use_resampler
:
image_hidden_states
=
self
.
perceiver_resampler
(
image_hidden_states
)
image_hidden_states
=
self
.
perceiver_resampler
(
image_hidden_states
)
image_seq_len
,
image_hidden_size
=
image_hidden_states
.
size
(
1
),
image_hidden_states
.
size
(
2
)
image_seq_len
,
image_hidden_size
=
image_hidden_states
.
size
(
image_hidden_states
=
image_hidden_states
.
view
(
batch_size
,
num_images
*
image_seq_len
,
image_hidden_size
)
1
),
image_hidden_states
.
size
(
2
)
image_hidden_states
=
image_hidden_states
.
view
(
batch_size
,
num_images
*
image_seq_len
,
image_hidden_size
)
else
:
else
:
no_images
=
False
no_images
=
False
num_images
=
pixel_values
.
shape
[
1
]
num_images
=
pixel_values
.
shape
[
1
]
...
@@ -1050,7 +1210,9 @@ class IdeficsModel(IdeficsPreTrainedModel):
...
@@ -1050,7 +1210,9 @@ class IdeficsModel(IdeficsPreTrainedModel):
text_seq_len
=
image_attention_mask
.
size
(
1
)
text_seq_len
=
image_attention_mask
.
size
(
1
)
image_attention_mask
=
image_attention_mask
.
unsqueeze
(
-
1
)
image_attention_mask
=
image_attention_mask
.
unsqueeze
(
-
1
)
image_attention_mask
=
image_attention_mask
.
repeat
(
1
,
1
,
1
,
image_seq_len
)
image_attention_mask
=
image_attention_mask
.
repeat
(
1
,
1
,
1
,
image_seq_len
)
image_attention_mask
=
image_attention_mask
.
view
(
batch_size
,
text_seq_len
,
num_images
*
image_seq_len
)
image_attention_mask
=
image_attention_mask
.
view
(
batch_size
,
text_seq_len
,
num_images
*
image_seq_len
)
image_batch_size
,
image_sequence_length
,
_
=
image_hidden_states
.
size
()
image_batch_size
,
image_sequence_length
,
_
=
image_hidden_states
.
size
()
image_hidden_shape
=
(
image_batch_size
,
image_sequence_length
)
image_hidden_shape
=
(
image_batch_size
,
image_sequence_length
)
if
image_attention_mask
is
None
:
if
image_attention_mask
is
None
:
...
@@ -1060,7 +1222,6 @@ class IdeficsModel(IdeficsPreTrainedModel):
        # if list(image_attention_mask.shape) != [4, 1, 1024, 64]:
        #     raise ValueError(f"Image hidden_states {image_hidden_states.shape} - mask {image_attention_mask.shape} {num_images} {image_seq_len} {text_seq_len}")
        # if image_hidden_states is not None:
        # else:
        #     image_attention_mask = None
...
@@ -1070,10 +1231,15 @@ class IdeficsModel(IdeficsPreTrainedModel):
        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past),
                dtype=torch.bool,
                device=inputs_embeds.device,
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
        )

        hidden_states = inputs_embeds
...
@@ -1094,7 +1260,9 @@ class IdeficsModel(IdeficsPreTrainedModel):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = (
                past_key_values[idx] if past_key_values is not None else None
            )

            def vblock(
                main_block,
...
@@ -1194,7 +1362,11 @@ class IdeficsModel(IdeficsPreTrainedModel):
        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None
            )
        return BaseModelOutputWithPastImage(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
...
@@ -1264,11 +1436,19 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
        "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
        ```"""
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
...
@@ -1298,7 +1478,7 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=outputs.image_hidden_states,
        )

    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
...
@@ -1316,12 +1496,20 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
        return expand_inputs_for_generation(*args, **model_kwargs)

    @staticmethod
    def _update_model_kwargs_for_generation(
        outputs, model_kwargs, is_encoder_decoder=False
    ):
        return update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=is_encoder_decoder
        )

    @staticmethod
    def _reorder_cache(past, beam_idx):
        reordered_past = ()
        for layer_past in past:
            reordered_past += (
                tuple(
                    past_state.index_select(0, beam_idx)
                    for past_state in layer_past
                ),
            )
        return reordered_past
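`_reorder_cache` keeps the KV cache aligned with beam-search hypotheses: whenever beams are permuted, every cached key/value tensor is permuted along its batch axis the same way. A minimal sketch of the idea (shapes are illustrative):

```python
import torch

# Cache for one layer: (key, value), each (num_beams, heads, seq, head_dim).
layer_past = (torch.randn(4, 2, 7, 8), torch.randn(4, 2, 7, 8))
# Beam search decided beam 2's continuation now occupies slots 0 and 1.
beam_idx = torch.tensor([2, 2, 0, 3])
reordered = tuple(p.index_select(0, beam_idx) for p in layer_past)
assert torch.equal(reordered[0][0], layer_past[0][2])
```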
server/text_generation_server/models/custom_modeling/idefics_perceiver.py
...
@@ -46,7 +46,8 @@ from text_generation_server.utils.layers import (
    TensorParallelRowLinear,
)

EPS = 1e-5


class IdeficsPerceiverResampler(nn.Module):
    def __init__(
...
@@ -78,7 +79,12 @@ class IdeficsPerceiverResampler(nn.Module):
        """
        super().__init__()
        self.embed_dim, self.n_heads, self.head_dim, self.n_latents = (
            embed_dim,
            n_heads,
            head_dim,
            n_latents,
        )
        self.qk_layer_norms = config.perceiver_config.qk_layer_norms_perceiver

        # Create Latents for Perceiver
...
@@ -107,14 +113,16 @@ class IdeficsPerceiverResampler(nn.Module):
                        prefix=f"{prefix}.blocks.{layer_id}.1",
                        intermediate_size=self.intermediate_dim,
                        config=config,
                        weights=weights,
                    ),
                ]
            )
            for layer_id in range(depth)
        ]
    )
        self.layer_norm = nn.LayerNorm.load(
            prefix=f"{prefix}.layer_norm", weights=weights, eps=EPS
        )

    def forward(self, context: torch.Tensor) -> torch.Tensor:
        """Resample arbitrary length context & *compress* down to self.n_latents latent embeddings"""
...
@@ -130,25 +138,34 @@ class IdeficsPerceiverResampler(nn.Module):
class IdeficsPerceiverAttention(nn.Module):
    def __init__(
        self,
        prefix,
        config,
        embed_dim: int,
        n_heads: int,
        head_dim: int,
        qk_layer_norms: bool,
        weights,
    ) -> None:
        """Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`"""
        super().__init__()
        self.embed_dim, self.n_heads, self.head_dim = embed_dim, n_heads, head_dim
        self.qk_layer_norms = qk_layer_norms
        # Normalization & Scaling
        self.context_layer_norm = nn.LayerNorm.load(
            prefix=f"{prefix}.context_layer_norm", weights=weights, eps=EPS
        )
        self.latents_layer_norm = nn.LayerNorm.load(
            prefix=f"{prefix}.latents_layer_norm", weights=weights, eps=EPS
        )
        if self.qk_layer_norms:
            self.q_layer_norm = nn.LayerNorm.load(
                prefix=f"{prefix}.q_layer_norm", weights=weights, eps=EPS
            )
            self.k_layer_norm = nn.LayerNorm.load(
                prefix=f"{prefix}.k_layer_norm", weights=weights, eps=EPS
            )

        self.qk_scale = self.head_dim**-0.5
...
@@ -202,7 +219,12 @@ class IdeficsPerceiverAttention(nn.Module):
        # Multiheaded Self-Attention w/ stable softmax (subtract per-row max -- `amax` -- before softmax call)
        # =>> `attn` should be a 2D matrix of shape [n_latents x (context + n_latents)]
        # einsum.rearrange(x, "bsz seq (heads embed) -> bsz heads seq embed", heads=self.n_heads)
        q, k, v = [
            x.reshape(batch_size, x.shape[1], self.n_heads, self.head_dim).transpose(
                1, 2
            )
            for x in (q, k, v)
        ]

        if self.qk_layer_norms:
            q = self.q_layer_norm(q)
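The list comprehension above is the standard head-splitting step: fold the flat hidden dimension into (heads, head_dim) and move heads in front of the sequence axis. A standalone sketch (sizes illustrative):

```python
import torch

batch_size, seq, n_heads, head_dim = 2, 10, 4, 16
x = torch.randn(batch_size, seq, n_heads * head_dim)
# (bsz, seq, heads*dim) -> (bsz, seq, heads, dim) -> (bsz, heads, seq, dim)
x = x.reshape(batch_size, seq, n_heads, head_dim).transpose(1, 2)
assert x.shape == (2, 4, 10, 16)
```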
...
@@ -219,7 +241,8 @@ class IdeficsPerceiverAttention(nn.Module):
class IdeficsMLP(nn.Module):
    def __init__(
        self,
        prefix,
        intermediate_size,
        config,
...
@@ -230,14 +253,22 @@ class IdeficsMLP(nn.Module):
        self.embed_dim = config.vision_config.embed_dim
        self.ln = nn.LayerNorm.load(prefix=f"{prefix}.ln", weights=weights, eps=EPS)
        self.fc = TensorParallelColumnLinear.load(
            config=config,
            prefix=f"{prefix}.fc",
            weights=weights,
            bias=False,
        )
        self.act = nn.ReLU()
        self.c_proj = TensorParallelRowLinear.load(
            config=config,
            prefix=f"{prefix}.c_proj",
            weights=weights,
            bias=False,
        )

    def forward(
        self, hidden_states: Optional[Tuple[torch.FloatTensor]]
    ) -> torch.FloatTensor:
        hidden_states = self.ln(hidden_states)
        hidden_states = self.fc(hidden_states)
        hidden_states = self.act(hidden_states)
...
server/text_generation_server/models/custom_modeling/idefics_processing.py
...
@@ -21,9 +21,16 @@ from urllib.parse import urlparse
from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_base import (
    BatchEncoding,
    PaddingStrategy,
    TextInput,
    TruncationStrategy,
)
from transformers.utils import TensorType, is_torch_available
from text_generation_server.models.custom_modeling.idefics_image_processing import (
    IdeficsImageProcessor,
)

if is_torch_available():
...
@@ -124,7 +131,14 @@ class IdeficsProcessor(ProcessorMixin):
    image_processor_class = "IdeficsImageProcessor"
    tokenizer_class = "LlamaTokenizerFast"

    def __init__(
        self,
        image_processor,
        tokenizer=None,
        image_size=224,
        add_end_of_utterance_token=None,
        **kwargs,
    ):
        if image_processor is None:
            raise ValueError("You need to specify an `image_processor`.")
        if tokenizer is None:
...
@@ -142,7 +156,8 @@ class IdeficsProcessor(ProcessorMixin):
        self.tokenizer_was_trained_with_end_of_utterance_token = (
            True
            if "<end_of_utterance>"
            in self.tokenizer.special_tokens_map.get("additional_special_tokens", [])
            else False
        )
...
@@ -265,7 +280,9 @@ class IdeficsProcessor(ProcessorMixin):
        # if the value isn't overridden by the user, check if the tokenizer was trained with this token and then use it
        if add_end_of_utterance_token is None:
            add_end_of_utterance_token = (
                self.tokenizer_was_trained_with_end_of_utterance_token
            )

        # turn non-batched prompts into batched
        if not any(isinstance(i, list) for i in prompts):
...
@@ -358,10 +375,14 @@ class IdeficsProcessor(ProcessorMixin):
            current_images = images[:local_max_num_images]
            if len(current_images) > 0:
                padded_image_tensor = torch.zeros(
                    max_num_images, *current_images.size()[1:]
                )
                padded_image_tensor[: current_images.size(0)] = current_images
            else:
                padded_image_tensor = torch.zeros(max_num_images, *self.default_image_dims)

            output_images.append(padded_image_tensor)
            output_input_ids.append(torch.tensor(padded_input_ids))
...
@@ -373,14 +394,19 @@ class IdeficsProcessor(ProcessorMixin):
        output_attention_masks = torch.stack(output_attention_masks)

        if at_least_one_image:
            image_attention_mask, _ = image_attention_mask_for_packed_input_ids(
                output_input_ids, self.tokenizer
            )
            image_attention_mask = incremental_to_binary_attention_mask(
                image_attention_mask, num_classes=max_num_images
            )
        else:
            # in full language mode we set the image mask to all-0s
            image_attention_mask = torch.zeros(
                output_input_ids.shape[0],
                output_input_ids.shape[1],
                1,
                dtype=torch.bool,
            )

        return BatchFeature(
...
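The padding above right-pads every sample to `max_num_images` all-zero images, and the mask conversion turns "token i last saw image j" indices into a binary token-by-image mask. A sketch of that conversion, under the assumption that `incremental_to_binary_attention_mask` is essentially a guarded one-hot (the helper's exact definition lives elsewhere in this file):

```python
import torch
import torch.nn.functional as F

# -1 means "no image seen yet"; otherwise the index of the most recent image.
incremental = torch.tensor([[-1, 0, 0, 1, 1]])
num_classes = 3
binary = F.one_hot(incremental.clamp(min=0), num_classes=num_classes)
binary[incremental == -1] = 0  # tokens before any image attend to nothing
# binary[b, t, j] == 1 iff token t of sample b attends to image j
```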
server/text_generation_server/models/custom_modeling/idefics_vision.py
...
@@ -75,7 +75,9 @@ class IdeficsVisionEmbeddings(nn.Module):
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(
            weights.get_tensor(f"{prefix}.class_embedding")
        )

        self.patch_embedding = nn.Conv2d.load_no_bias(
            prefix=f"{prefix}.patch_embedding",
...
@@ -91,12 +93,16 @@ class IdeficsVisionEmbeddings(nn.Module):
        self.position_embedding = TensorParallelEmbedding(
            prefix="model.vision_model.embeddings.position_embedding", weights=weights
        )
        self.position_ids = (
            torch.arange(self.num_positions).expand((1, -1)).to(device=weights.device)
        )

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        batch_size = pixel_values.shape[0]
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(
            pixel_values.to(dtype=target_dtype)
        )  # shape = [*, width, grid, grid]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
...
@@ -132,7 +138,6 @@ class IdeficsVisionAttention(nn.Module):
        self.num_heads = self.num_heads // weights.process_group.size()
        self.embed_dim = self.embed_dim // weights.process_group.size()
        self.k_proj = TensorParallelColumnLinear.load(
            config, prefix=f"{prefix}.k_proj", weights=weights, bias=True
        )
...
@@ -147,7 +152,11 @@ class IdeficsVisionAttention(nn.Module):
        )

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return (
            tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
            .contiguous()
        )

    def forward(
        self,
...
@@ -186,7 +195,10 @@ class IdeficsVisionAttention(nn.Module):
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
                    f" {causal_attention_mask.size()}"
                )
            attn_weights = (
                attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
                + causal_attention_mask
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if attention_mask is not None:
...
@@ -194,7 +206,10 @@ class IdeficsVisionAttention(nn.Module):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = (
                attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
                + attention_mask
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
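Both branches apply the mask additively: positions to be ignored carry a large negative value, so they vanish after the softmax. A minimal sketch of the mechanism (toy numbers):

```python
import torch

scores = torch.tensor([[2.0, 1.0, 0.5]])
mask = torch.tensor([[0.0, 0.0, float("-inf")]])  # hide the last position
probs = torch.softmax(scores + mask, dim=-1)
assert probs[0, 2] == 0.0  # the masked position gets zero attention
```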
...
@@ -204,12 +219,18 @@ class IdeficsVisionAttention(nn.Module):
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(
                bsz, self.num_heads, tgt_len, src_len
            )
            attn_weights = attn_weights_reshaped.view(
                bsz * self.num_heads, tgt_len, src_len
            )
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(
            attn_weights, p=self.dropout, training=self.training
        )

        attn_output = torch.bmm(attn_probs, value_states)
...
@@ -253,11 +274,15 @@ class IdeficsVisionEncoderLayer(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = IdeficsVisionAttention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights
        )
        self.layer_norm1 = nn.LayerNorm.load(
            prefix=f"{prefix}.layer_norm1", weights=weights, eps=config.layer_norm_eps
        )
        self.mlp = IdeficsVisionMLP(
            prefix=f"{prefix}.mlp", config=config, weights=weights
        )
        self.layer_norm2 = nn.LayerNorm.load(
            prefix=f"{prefix}.layer_norm2", weights=weights, eps=config.layer_norm_eps
        )
...
@@ -318,7 +343,11 @@ class IdeficsVisionEncoder(nn.Module):
        self.config = config
        self.layers = nn.ModuleList(
            [
                IdeficsVisionEncoderLayer(
                    prefix=f"{prefix}.encoder.layers.{layer_id}",
                    config=config,
                    weights=weights,
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )
...
@@ -362,11 +391,19 @@ class IdeficsVisionEncoder(nn.Module):
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
...
@@ -406,9 +443,15 @@ class IdeficsVisionEncoder(nn.Module):
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, encoder_states, all_attentions]
                if v is not None
            )
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
            attentions=all_attentions,
        )
...
@@ -419,13 +462,19 @@ class IdeficsVisionTransformer(nn.Module):
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = IdeficsVisionEmbeddings(
            prefix=f"{prefix}.embeddings", config=config, weights=weights
        )
        self.pre_layrnorm = nn.LayerNorm.load(
            prefix=f"{prefix}.pre_layrnorm", weights=weights, eps=config.layer_norm_eps
        )
        self.encoder = IdeficsVisionEncoder(prefix=prefix, config=config, weights=weights)
        self.post_layernorm = nn.LayerNorm.load(
            prefix=f"{prefix}.post_layernorm",
            weights=weights,
            eps=config.layer_norm_eps,
        )

    # copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer.forward
...
@@ -440,11 +489,19 @@ class IdeficsVisionTransformer(nn.Module):
        Returns:
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")
...
server/text_generation_server/models/custom_modeling/neox_modeling.py
...
@@ -49,7 +49,10 @@ from text_generation_server.utils.layers import (
CUSTOM_KERNELS_ENABLED = False
if (
    torch.cuda.is_available()
    and not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True"
):
    try:
        from custom_kernels import fused_attention_cuda
...
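The guard keeps the optional CUDA kernels behind both a hardware check and an environment switch, so a missing build degrades gracefully. A simplified sketch of the same pattern in isolation (the extension module name is hypothetical, and the CUDA check is omitted to keep the snippet self-contained):

```python
import os

ENABLED = False
if os.environ.get("DISABLE_CUSTOM_KERNELS", "False") != "True":
    try:
        import my_fused_kernels  # hypothetical optional extension
        ENABLED = True
    except ImportError:
        pass  # fall back to the pure-PyTorch code path
```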
server/text_generation_server/models/flash_causal_lm.py
...
@@ -1005,9 +1005,12 @@ class FlashCausalLM(Model):
                # Decode generated tokens
                output_text, _, _ = self.decode_token(
                    all_input_ids,
                    prefix_offset=len(all_input_ids)
                    - stopping_criteria.current_tokens
                    - 1,
                    read_offset=len(all_input_ids)
                    - stopping_criteria.current_tokens,
                    skip_special_tokens=True,
                )
                generated_text = GeneratedText(
                    output_text,
...
server/text_generation_server/models/idefics_causal_lm.py
...
@@ -8,7 +8,13 @@ import re
from dataclasses import dataclass
from opentelemetry import trace
from transformers import (
    AutoProcessor,
    AutoTokenizer,
    AutoModelForCausalLM,
    PreTrainedTokenizerBase,
    ProcessorMixin,
)
from typing import Optional, Tuple, List, Type, Dict

from text_generation_server.models import Model
...
@@ -23,7 +29,8 @@ from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sam
import re

IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")

def split(string):
    parts = []
...
@@ -41,6 +48,7 @@ def split(string):
    return parts

tracer = trace.get_tracer(__name__)
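`IMAGES` recognizes Markdown image syntax so prompts can interleave text and image URLs; the first capture group is the URL and the optional second is a quoted title. A quick check of what it matches (input is illustrative):

```python
import re

IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")

m = IMAGES.search("Describe this: ![a cat](https://example.com/cat.png \"my cat\")")
assert m.group(1) == "https://example.com/cat.png"
assert m.group(2) == '"my cat"'
```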
...
@@ -141,8 +149,12 @@ class IdeficsCausalLMBatch(Batch):
        ).to(device)
        for _ in pb.requests:
            input_len = tokenized_inputs["input_ids"].shape[1]
            prefix_offsets.append(
                input_len - 5
            )  # To decode without potential fallback errors
            read_offsets.append(
                input_len
            )  # To decode without potential fallback errors

        input_lengths = tokenized_inputs["attention_mask"].sum(1)
        max_input_length = input_lengths.max()
...
@@ -158,14 +170,21 @@ class IdeficsCausalLMBatch(Batch):
        attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
        # Do the same for image_attention_mask
        image_attention_mask = input_ids.new_zeros(
            (
                pb.size,
                max_input_length + padding_right_offset,
                tokenized_inputs["pixel_values"].size(1),
            )
        )
        image_attention_mask[:, :max_input_length, :] = tokenized_inputs[
            "image_attention_mask"
        ]

        position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
        position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
        all_input_ids = tokenized_inputs["input_ids"].T.split(
            1, dim=1
        )  # It's input_ids but split into a tuple of tensors where each tensor is (seq_len, 1) size. It is then transformed into a list

        max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
...
@@ -259,7 +278,7 @@ class IdeficsCausalLMBatch(Batch):
                self.image_attention_mask.shape[1] - self.padding_right_offset
            )
            + new_padding_right_offset,
            :,
        ]
        if self.image_hidden_states is None:
            image_hidden_states = None
...
@@ -308,7 +327,9 @@ class IdeficsCausalLMBatch(Batch):
    @classmethod
    @tracer.start_as_current_span("concatenate")
    def concatenate(
        cls, batches: List["IdeficsCausalLMBatch"]
    ) -> "IdeficsCausalLMBatch":
        # It adds new requests to the batch
        # Used for padding
        total_batch_size = 0
...
@@ -383,12 +404,20 @@ class IdeficsCausalLMBatch(Batch):
            curr_batch_max_num_images = batch.pixel_values.size(1)
            if pixel_values is None:
                pixel_values = batch.pixel_values.new_zeros(
                    (total_batch_size, max_num_images, 3, 224, 224)
                )
            pixel_values[
                start_index:end_index, :curr_batch_max_num_images
            ] = batch.pixel_values

            if image_attention_mask is None:
                image_attention_mask = batch.image_attention_mask.new_zeros(
                    (
                        total_batch_size,
                        max_input_length + padding_right_offset,
                        max_num_images,
                    )
                )

            # We need to slice the attention mask to remove padding from previous steps
...
@@ -409,11 +438,9 @@ class IdeficsCausalLMBatch(Batch):
            image_attention_mask[
                start_index:end_index,
                left_offset:-padding_right_offset,
                :curr_batch_max_num_images,
            ] = batch.image_attention_mask[
                :, batch_left_offset : -batch.padding_right_offset, :
            ]

        # Create empty tensor
...
@@ -550,7 +577,9 @@ class IdeficsCausalLM(Model):
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        from text_generation_server.models.custom_modeling.idefics_modeling import (
            IdeficsForVisionText2Text,
        )

        if torch.cuda.is_available():
            device = torch.device("cuda")
...
@@ -650,9 +679,13 @@ class IdeficsCausalLM(Model):
            # this is due to the nature of IDEFICS: it's an encoder-decoder, and so when decoding, only the currently generated
            # token needs to attend to the encoder hidden states (i.e. the vision encoder)
            # Also see seq2seq_lm.Seq2SeqLM.generate_token which has roughly the same logic
            image_attention_mask = batch.image_attention_mask[
                :, -(batch.padding_right_offset + 1)
            ].unsqueeze(1)
        else:
            image_attention_mask = batch.image_attention_mask[
                :, : -batch.padding_right_offset
            ]

        logits, past, image_hidden_states = self.forward(
            input_ids=batch.input_ids,
...
@@ -725,9 +758,12 @@ class IdeficsCausalLM(Model):
                # Decode generated tokens
                output_text, _, _ = self.decode_token(
                    all_input_ids[:, 0],
                    prefix_offset=len(all_input_ids)
                    - stopping_criteria.current_tokens
                    - 1,
                    read_offset=len(all_input_ids)
                    - stopping_criteria.current_tokens,
                    skip_special_tokens=True,
                )
                # Get seed
                if isinstance(next_token_chooser.choice, Sampling):
...
@@ -761,7 +797,7 @@ class IdeficsCausalLM(Model):
            else:
                prefill_tokens = None

            top_tokens = None

            generation = Generation(
                request.id,
...
@@ -771,7 +807,7 @@ class IdeficsCausalLM(Model):
                next_token_text,
                next_token_id_squeezed.item() in self.all_special_ids,
                generated_text,
                top_tokens,
            )

            generations.append(generation)
...
@@ -793,7 +829,9 @@ class IdeficsCausalLM(Model):
        # Update attention_mask as we added a new token to input_ids
        batch.attention_mask[:, -batch.padding_right_offset] = 1
        batch.image_attention_mask[
            :, -batch.padding_right_offset, :
        ] = batch.image_attention_mask[:, -(batch.padding_right_offset + 1), :]
        # Decrease right offset
        batch.padding_right_offset -= 1
...
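The batch keeps `padding_right_offset` slots of right padding for tokens still to be generated; after each step the newest column of the image mask is copied forward from the previous one and the offset shrinks by one. A toy walkthrough of that bookkeeping (shapes illustrative):

```python
import torch

seq_len, padding_right_offset = 6, 2
image_attention_mask = torch.zeros(1, seq_len, 1)
image_attention_mask[:, : seq_len - padding_right_offset, :] = 1

# One decode step: the freshly generated token reuses the previous column's mask.
image_attention_mask[:, -padding_right_offset, :] = image_attention_mask[
    :, -(padding_right_offset + 1), :
]
padding_right_offset -= 1
assert image_attention_mask[0, :, 0].tolist() == [1, 1, 1, 1, 1, 0]
```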
server/text_generation_server/models/model.py
...
@@ -71,7 +71,8 @@ class Model(ABC):
        # The prefix text is necessary only to defeat cleanup algorithms in the decode
        # which decide to add a space or not depending on the surrounding ids.
        prefix_text = self.tokenizer.decode(
            all_input_ids[prefix_offset:read_offset],
            skip_special_tokens=skip_special_tokens,
        )
        new_text = self.tokenizer.decode(
            all_input_ids[prefix_offset:], skip_special_tokens=skip_special_tokens
        )
...
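The two decodes implement incremental detokenization: decode a small window of ids twice, once up to `read_offset` and once past it, and emit only the suffix. That way tokenizer cleanup (spacing, byte-fallback merges) is applied with context instead of per token. A hedged sketch of the scheme; the suffix-emission condition goes beyond the hunk shown here and is an assumption:

```python
def decode_token(tokenizer, all_input_ids, prefix_offset, read_offset):
    # Decode the overlap window with and without the newest ids.
    prefix_text = tokenizer.decode(all_input_ids[prefix_offset:read_offset])
    new_text = tokenizer.decode(all_input_ids[prefix_offset:])
    if len(new_text) > len(prefix_text) and "\uFFFD" not in new_text:
        # The new ids produced clean text: emit only the fresh suffix.
        return new_text[len(prefix_text):], read_offset, len(all_input_ids)
    # Otherwise wait for more ids (e.g. an incomplete UTF-8 sequence).
    return "", prefix_offset, read_offset
```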