Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
text-generation-inference
Commits
06d0e880
Unverified
Commit
06d0e880
authored
Jul 16, 2024
by
Daniël de Kok
Committed by
GitHub
Jul 16, 2024
Browse files
Add support for AWQ-quantized Idefics2 (#2233)
Fixes #2036.
parent
0ad7f6f8
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
39 additions
and
11 deletions
+39
-11
server/text_generation_server/models/custom_modeling/idefics2.py
...text_generation_server/models/custom_modeling/idefics2.py
+24
-11
server/text_generation_server/utils/weights.py
server/text_generation_server/utils/weights.py
+15
-0
No files found.
server/text_generation_server/models/custom_modeling/idefics2.py
View file @
06d0e880
...
@@ -34,6 +34,7 @@ from text_generation_server.layers import (
...
@@ -34,6 +34,7 @@ from text_generation_server.layers import (
TensorParallelEmbedding
,
TensorParallelEmbedding
,
TensorParallelRowLinear
,
TensorParallelRowLinear
,
)
)
from
text_generation_server.utils.weights
import
DefaultWeightsLoader
def
repeat_kv
(
hidden_states
:
torch
.
Tensor
,
n_rep
:
int
)
->
torch
.
Tensor
:
def
repeat_kv
(
hidden_states
:
torch
.
Tensor
,
n_rep
:
int
)
->
torch
.
Tensor
:
...
@@ -682,7 +683,7 @@ class Idefics2Connector(nn.Module):
...
@@ -682,7 +683,7 @@ class Idefics2Connector(nn.Module):
class
Idefics2ForConditionalGeneration
(
nn
.
Module
):
class
Idefics2ForConditionalGeneration
(
nn
.
Module
):
def
__init__
(
self
,
prefix
,
config
,
weights
):
def
__init__
(
self
,
prefix
,
config
,
weights
):
super
().
__init__
()
super
().
__init__
()
config
.
vision_config
.
quantize
=
c
on
fig
.
quantiz
e
config
.
vision_config
.
quantize
=
N
one
config
.
vision_config
.
speculator
=
config
.
speculator
config
.
vision_config
.
speculator
=
config
.
speculator
config
.
text_config
.
quantize
=
config
.
quantize
config
.
text_config
.
quantize
=
config
.
quantize
config
.
text_config
.
speculator
=
config
.
speculator
config
.
text_config
.
speculator
=
config
.
speculator
...
@@ -695,16 +696,28 @@ class Idefics2ForConditionalGeneration(nn.Module):
...
@@ -695,16 +696,28 @@ class Idefics2ForConditionalGeneration(nn.Module):
name
=
"text_model"
,
name
=
"text_model"
,
)
)
self
.
dtype
=
weights
.
dtype
self
.
dtype
=
weights
.
dtype
self
.
vision_model
=
Idefics2VisionTransformer
(
prefix
=
f
"
{
prefix
}
.model.vision_model"
if
prefix
else
"model.vision_model"
,
# The vision and connector models are not quantized.
config
=
vision_config
,
with
weights
.
use_loader
(
DefaultWeightsLoader
()):
weights
=
weights
,
self
.
vision_model
=
Idefics2VisionTransformer
(
)
prefix
=
(
self
.
connector
=
Idefics2Connector
(
f
"
{
prefix
}
.model.vision_model"
if
prefix
else
"model.vision_model"
prefix
=
f
"
{
prefix
}
.model.connector"
if
prefix
else
"model.connector"
,
),
config
=
config
,
config
=
vision_config
,
weights
=
weights
,
weights
=
weights
,
)
)
quantize
=
config
.
quantize
try
:
config
.
quantize
=
None
self
.
connector
=
Idefics2Connector
(
prefix
=
f
"
{
prefix
}
.model.connector"
if
prefix
else
"model.connector"
,
config
=
config
,
weights
=
weights
,
)
finally
:
config
.
quantize
=
quantize
self
.
config
=
config
self
.
config
=
config
self
.
image_seq_len
=
config
.
perceiver_config
.
resampler_n_latents
self
.
image_seq_len
=
config
.
perceiver_config
.
resampler_n_latents
self
.
image_token_id
=
config
.
image_token_id
self
.
image_token_id
=
config
.
image_token_id
...
...
server/text_generation_server/utils/weights.py
View file @
06d0e880
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
contextlib
import
contextmanager
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Dict
,
List
,
Optional
,
Union
from
typing
import
Dict
,
List
,
Optional
,
Union
from
safetensors
import
safe_open
from
safetensors
import
safe_open
...
@@ -306,6 +307,20 @@ class Weights:
...
@@ -306,6 +307,20 @@ class Weights:
def
get_weights_row
(
self
,
prefix
:
str
):
def
get_weights_row
(
self
,
prefix
:
str
):
return
self
.
weights_loader
.
get_weights_row
(
self
,
prefix
)
return
self
.
weights_loader
.
get_weights_row
(
self
,
prefix
)
@
contextmanager
def
use_loader
(
self
,
weights_loader
:
WeightsLoader
):
"""
This method is a context manager that can be used to use `Weights` with
a different loader for the duration of the context.
"""
old_loader
=
self
.
weights_loader
self
.
weights_loader
=
weights_loader
try
:
yield
finally
:
self
.
weights_loader
=
old_loader
def
_blocks_to_block_sizes
(
total_size
:
int
,
blocks
:
Union
[
int
,
List
[
int
]])
->
List
[
int
]:
def
_blocks_to_block_sizes
(
total_size
:
int
,
blocks
:
Union
[
int
,
List
[
int
]])
->
List
[
int
]:
"""
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment