OpenDAS / text-generation-inference

Commit e14ae3b5 (unverified)
Authored by OlivierDehaene on Apr 19, 2023; committed by GitHub on Apr 19, 2023
feat(server): support quantization for flash models (#200)
closes #197
parent 2475aede

Showing 8 changed files with 196 additions and 83 deletions (+196 -83)
+5  -3   server/text_generation_server/models/__init__.py
+49 -13  server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+60 -17  server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+49 -13  server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+2  -4   server/text_generation_server/models/flash_causal_lm.py
+7  -12  server/text_generation_server/models/flash_llama.py
+7  -6   server/text_generation_server/models/flash_neox.py
+17 -15  server/text_generation_server/models/flash_santacoder.py
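All eight files follow the same pattern: FastLinear.transpose_weight() becomes FastLinear.prepare_weights(quantize), which either pre-transposes the fp16 weight (the old behaviour) or swaps it for a bitsandbytes Linear8bitLt, and the quantize flag is threaded from the model constructors down through load_weights and post_load_weights. The snippet below is a condensed, self-contained sketch of that pattern, not the repo's exact code; it treats bitsandbytes as optional and only exercises the non-quantized path in the demo, since the int8 path needs a CUDA GPU.

# Condensed sketch of the new FastLinear quantization hook (illustrative, not the
# repo's exact code).
import torch
from torch import nn

HAS_BITS_AND_BYTES = True
try:
    from bitsandbytes.nn import Linear8bitLt
except ImportError:
    HAS_BITS_AND_BYTES = False


class FastLinear(nn.Linear):
    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
        super().__init__(in_features, out_features, bias, device, dtype)
        self.quantized = False
        self.bnb_linear = None

    def prepare_weights(self, quantize: bool = False):
        if quantize:
            if not HAS_BITS_AND_BYTES:
                raise ImportError("`quantize=True` requires bitsandbytes and a GPU")
            self.quantized = True
            # int8 layer that quantizes the copied fp16 weights once moved to the GPU
            self.bnb_linear = Linear8bitLt(
                self.in_features,
                self.out_features,
                has_fp16_weights=False,
                threshold=6.0,
                bias=False,
            )
            self.bnb_linear.weight.data = self.weight.data
            if self.bias is not None:
                self.bnb_linear.bias = nn.Parameter(self.bias)
            # Drop the original references so only the bnb layer keeps the data
            self.weight = None
            self.bias = None
        else:
            # Unquantized path: pre-transpose once so forward can use addmm/matmul
            self.weight = nn.Parameter(self.weight.T)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self.quantized:
            return self.bnb_linear(input)
        if self.bias is not None:
            return torch.addmm(self.bias, input, self.weight)
        return torch.matmul(input, self.weight)


if __name__ == "__main__":
    layer = FastLinear(16, 32)
    layer.prepare_weights(quantize=False)  # quantize=True needs CUDA + bitsandbytes
    print(layer(torch.randn(4, 16)).shape)  # torch.Size([4, 32])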
server/text_generation_server/models/__init__.py

@@ -26,7 +26,9 @@ try:
     FLASH_ATTENTION = torch.cuda.is_available()
 except ImportError:
     logger.opt(exception=True).warning("Could not import Flash Attention enabled models")
     FLASH_ATTENTION = False
 
 __all__ = [

@@ -88,10 +90,10 @@ def get_model(
                 raise NotImplementedError(
                     FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
                 )
-            return FlashSantacoderSharded(model_id, revision=revision)
+            return FlashSantacoderSharded(model_id, revision, quantize=quantize)
         else:
             santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
-            return santacoder_cls(model_id, revision, quantize)
+            return santacoder_cls(model_id, revision, quantize=quantize)
 
     config = AutoConfig.from_pretrained(model_id, revision=revision)
     model_type = config.model_type
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py

@@ -33,6 +33,12 @@ import dropout_layer_norm
 from flash_attn.layers.rotary import RotaryEmbedding
 
+HAS_BITS_AND_BYTES = True
+try:
+    from bitsandbytes.nn import Linear8bitLt
+except ImportError as e:
+    HAS_BITS_AND_BYTES = False
+
 
 class LlamaRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):

@@ -94,14 +100,44 @@ class FastLinear(nn.Linear):
         dtype=None,
     ) -> None:
         super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
+        self.quantized = False
+        self.bnb_linear = None
 
-    def transpose_weight(self):
-        self.weight = nn.Parameter(self.weight.T)
+    def prepare_weights(self, quantize: bool = False):
+        if quantize:
+            if not HAS_BITS_AND_BYTES:
+                raise ImportError(
+                    "bitsandbytes is not available on your machine either because it is not installed "
+                    "or you don't have a GPU.\n"
+                    "You can install it with `pip install bitsandbytes`."
+                )
+
+            self.quantized = True
+            self.bnb_linear = Linear8bitLt(
+                self.in_features,
+                self.out_features,
+                has_fp16_weights=False,
+                threshold=6.0,
+                bias=False,
+            )
+            # Copy data to bnb_linear
+            self.bnb_linear.weight.data = self.weight.data
+            if self.bias is not None:
+                self.bnb_linear.bias = nn.Parameter(self.bias)
+
+            # Delete reference to data
+            self.weight = None
+            self.bias = None
+        else:
+            self.weight = nn.Parameter(self.weight.T)
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        if self.bias is not None:
-            return torch.addmm(self.bias, input, self.weight)
-        return torch.matmul(input, self.weight)
+        if self.quantized:
+            return self.bnb_linear(input)
+        else:
+            if self.bias is not None:
+                return torch.addmm(self.bias, input, self.weight)
+            return torch.matmul(input, self.weight)
 
 
 class TensorParallelColumnLinear(FastLinear):

@@ -502,15 +538,15 @@ class FlashLlamaModel(torch.nn.Module):
         self.head_size = self.layers[0].self_attn.head_size
         self.num_heads = self.layers[0].self_attn.num_heads
 
-    def post_load_weights(self):
+    def post_load_weights(self, load_in_8bit: bool = False):
         if isinstance(self.embed_tokens, TensorParallelEmbedding):
             self.embed_tokens.add_null_idx()
         for layer in self.layers:
             layer: FlashLlamaLayer
-            layer.self_attn.query_key_value.transpose_weight()
-            layer.self_attn.o_proj.transpose_weight()
-            layer.mlp.gate_up_proj.transpose_weight()
-            layer.mlp.down_proj.transpose_weight()
+            layer.self_attn.query_key_value.prepare_weights(load_in_8bit)
+            layer.self_attn.o_proj.prepare_weights(load_in_8bit)
+            layer.mlp.gate_up_proj.prepare_weights(load_in_8bit)
+            layer.mlp.down_proj.prepare_weights(load_in_8bit)
 
     def forward(
         self,

@@ -592,9 +628,9 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         else:
             self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
 
-    def post_load_weights(self):
-        self.model.post_load_weights()
-        self.lm_head.transpose_weight()
+    def post_load_weights(self, load_in_8bit: bool = False):
+        self.model.post_load_weights(load_in_8bit)
+        self.lm_head.prepare_weights()
 
     def forward(
         self,
server/text_generation_server/models/custom_modeling/flash_neox_modeling.py

@@ -35,6 +35,12 @@ import dropout_layer_norm
 from flash_attn.layers.rotary import RotaryEmbedding
 
+HAS_BITS_AND_BYTES = True
+try:
+    from bitsandbytes.nn import Linear8bitLt
+except ImportError as e:
+    HAS_BITS_AND_BYTES = False
+
 
 class FastLayerNorm(nn.LayerNorm):
     def forward(self, hidden_states, residual=None):

@@ -82,14 +88,44 @@ class FastLinear(nn.Linear):
(Same FastLinear change as in flash_llama_modeling.py above: transpose_weight is replaced by prepare_weights, which builds a bitsandbytes Linear8bitLt when quantizing or pre-transposes the weight otherwise, and forward dispatches on self.quantized.)

@@ -552,23 +588,27 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         self.head_size = self.layers[0].attention.head_size
         self.num_heads = self.layers[0].attention.num_heads
 
-    def post_load_weights(self):
+    def post_load_weights(self, load_in_8bit=False):
         if isinstance(self.embed_in, TensorParallelEmbedding):
             self.embed_in.add_null_idx()
         for layer in self.layers:
             layer: FlashNeoXLayer
             layer.attention.shuffle_qkv_dims()
-            layer.attention.query_key_value.transpose_weight()
-            layer.attention.dense.transpose_weight()
-            layer.mlp.dense_h_to_4h.transpose_weight()
-            layer.mlp.dense_4h_to_h.transpose_weight()
+            layer.attention.query_key_value.prepare_weights(load_in_8bit)
+            layer.attention.dense.prepare_weights(load_in_8bit)
+            layer.mlp.dense_h_to_4h.prepare_weights(load_in_8bit)
+            layer.mlp.dense_4h_to_h.prepare_weights(load_in_8bit)
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        # Pop here as we will replace the layer in our own logic and don't want from_pretrained
+        # to do it for us
+        load_in_8bit = kwargs.pop("load_in_8bit", False)
         model = super(FlashGPTNeoXModel, cls).from_pretrained(
-            pretrained_model_name_or_path, *model_args, **kwargs
+            pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
         )
-        model.post_load_weights()
+        model.post_load_weights(load_in_8bit)
         return model
 
     def forward(

@@ -653,16 +693,19 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
             config.hidden_size, config.vocab_size, bias=False
         )
 
-    def post_load_weights(self):
-        self.gpt_neox.post_load_weights()
-        self.embed_out.transpose_weight()
+    def post_load_weights(self, load_in_8bit=False):
+        self.gpt_neox.post_load_weights(load_in_8bit)
+        self.embed_out.prepare_weights()
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        # Pop here as we will replace the layer in our own logic and don't want from_pretrained
+        # to do it for us
+        load_in_8bit = kwargs.pop("load_in_8bit", False)
         model = super(FlashGPTNeoXForCausalLM, cls).from_pretrained(
-            pretrained_model_name_or_path, *model_args, **kwargs
+            pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
         )
-        model.post_load_weights()
+        model.post_load_weights(load_in_8bit)
         return model
 
     def forward(
server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py

@@ -10,6 +10,12 @@ from transformers.activations import ACT2FN
 import flash_attn_cuda
 import dropout_layer_norm
 
+HAS_BITS_AND_BYTES = True
+try:
+    from bitsandbytes.nn import Linear8bitLt
+except ImportError as e:
+    HAS_BITS_AND_BYTES = False
+
 
 class FastLayerNorm(nn.LayerNorm):
     def forward(self, hidden_states, residual=None):

@@ -57,14 +63,44 @@ class FastLinear(nn.Linear):
(Same FastLinear change as in flash_llama_modeling.py above: transpose_weight is replaced by prepare_weights and forward dispatches on self.quantized.)

@@ -431,16 +467,16 @@ class FlashSantacoderModel(nn.Module):
         self.head_size = self.h[0].attn.head_size
         self.num_heads = self.h[0].attn.num_heads
 
-    def post_load_weights(self):
+    def post_load_weights(self, load_in_8bit: bool = False):
         if self.tp_embeddings:
             self.wte.add_null_idx()
             self.wpe.add_null_idx()
         for layer in self.h:
             layer: Block
-            layer.attn.c_attn.transpose_weight()
-            layer.attn.c_proj.transpose_weight()
-            layer.mlp.c_fc.transpose_weight()
-            layer.mlp.c_proj.transpose_weight()
+            layer.attn.c_attn.prepare_weights(load_in_8bit)
+            layer.attn.c_proj.prepare_weights(load_in_8bit)
+            layer.mlp.c_fc.prepare_weights(load_in_8bit)
+            layer.mlp.c_proj.prepare_weights(load_in_8bit)
 
     def forward(
         self,

@@ -508,9 +544,9 @@ class FlashSantacoderForCausalLM(nn.Module):
         else:
             self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
 
-    def post_load_weights(self):
-        self.transformer.post_load_weights()
-        self.lm_head.transpose_weight()
+    def post_load_weights(self, load_in_8bit: bool = False):
+        self.transformer.post_load_weights(load_in_8bit)
+        self.lm_head.prepare_weights()
 
     def forward(
         self,
server/text_generation_server/models/flash_causal_lm.py

@@ -221,9 +221,6 @@ class FlashCausalLM(Model):
         else:
             raise NotImplementedError("FlashCausalLM is only available on GPU")
 
-        if quantize:
-            raise NotImplementedError("FlashCausalLM does not support quantization")
-
         tokenizer = AutoTokenizer.from_pretrained(
             model_id, revision=revision, padding_side="left", truncation_side="left"
         )

@@ -232,9 +229,10 @@ class FlashCausalLM(Model):
                 model_id,
                 revision=revision,
                 torch_dtype=dtype,
+                load_in_8bit=quantize,
             )
             .eval()
-            .cuda()
+            .to(device)
         )
 
         super(FlashCausalLM, self).__init__(
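For the FlashCausalLM base class the heavy lifting is delegated to transformers: load_in_8bit=quantize is simply forwarded to from_pretrained, and the resulting model is moved with .to(device) instead of .cuda(). A hedged, minimal illustration of that flag (the model id is a placeholder, not from this commit; load_in_8bit=True requires a CUDA GPU plus the bitsandbytes and accelerate packages):

import torch
from transformers import AutoModelForCausalLM

# Placeholder checkpoint for illustration only; any causal LM works the same way.
model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/pythia-160m",
    torch_dtype=torch.float16,
    load_in_8bit=True,  # bitsandbytes converts the linear weights to int8 at load time
).eval()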
server/text_generation_server/models/flash_llama.py

@@ -35,9 +35,6 @@ class FlashLlama(FlashCausalLM):
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")
 
-        if quantize:
-            raise NotImplementedError("FlashLlama does not support quantization")
-
         tokenizer = LlamaTokenizer.from_pretrained(
             model_id,
             revision=revision,

@@ -61,8 +58,8 @@ class FlashLlama(FlashCausalLM):
         with init_empty_weights():
             model = FlashLlamaForCausalLM(config)
 
-        self.load_weights(model, filenames, device, dtype)
-        self.model = model.eval()
+        self.load_weights(model, filenames, quantize, device, dtype)
+        self.model = model.eval().to(device)
 
         super(FlashCausalLM, self).__init__(
             tokenizer=tokenizer,

@@ -73,13 +70,14 @@ class FlashLlama(FlashCausalLM):
     def load_weights(
         model,
         filenames: List[Path],
+        quantize: bool,
         device: torch.device,
         dtype: torch.dtype,
     ):
         for filename in filenames:
             state_dict = torch.load(filename, map_location="cpu")
             for key, value in state_dict.items():
-                value = value.to(device).to(dtype)
+                value = value.to(device if not quantize else "cpu").to(dtype)
 
                 layer_name = ".".join(key.split(".")[:4])

@@ -139,7 +137,7 @@ class FlashLlama(FlashCausalLM):
             del value
 
         torch.cuda.empty_cache()
-        model.post_load_weights()
+        model.post_load_weights(quantize)
 
 
 class FlashLlamaSharded(FlashLlama):

@@ -154,9 +152,6 @@ class FlashLlamaSharded(FlashLlama):
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")
 
-        if quantize:
-            raise NotImplementedError("FlashLlama does not support quantization")
-
         tokenizer = LlamaTokenizer.from_pretrained(
             model_id,
             revision=revision,

@@ -185,7 +180,7 @@ class FlashLlamaSharded(FlashLlama):
             rank=self.rank,
             world_size=self.world_size,
         )
-        self.model = model.eval()
+        self.model = model.eval().to(device)
         torch.distributed.barrier(group=self.process_group)
 
         super(FlashCausalLM, self).__init__(
             tokenizer=tokenizer,

@@ -300,4 +295,4 @@ class FlashLlamaSharded(FlashLlama):
             else:
                 module._buffers[param_name] = tensor
         torch.cuda.empty_cache()
-        model.post_load_weights()
+        model.post_load_weights(quantize)
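Note how the loader above keeps tensors on "cpu" while quantizing (value.to(device if not quantize else "cpu")) and only calls .to(device) on the assembled model at the end. A minimal sketch of why, assuming a CUDA GPU and bitsandbytes installed: with has_fp16_weights=False, Linear8bitLt holds the copied fp16 weights until the module is moved to the GPU and performs the int8 quantization on that device transfer.

import torch
from bitsandbytes.nn import Linear8bitLt

# Illustrative only: same constructor arguments as the FastLinear change above.
linear = Linear8bitLt(1024, 1024, has_fp16_weights=False, threshold=6.0, bias=False)
linear.weight.data = torch.randn(1024, 1024, dtype=torch.float16)

linear = linear.to("cuda")  # int8 quantization happens on this device transfer
out = linear(torch.randn(4, 1024, dtype=torch.float16, device="cuda"))
print(out.shape, linear.weight.dtype)  # torch.Size([4, 1024]) torch.int8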
server/text_generation_server/models/flash_neox.py

@@ -41,9 +41,6 @@ class FlashNeoXSharded(FlashNeoX):
         else:
             raise NotImplementedError("FlashNeoX is only available on GPU")
 
-        if quantize:
-            raise NotImplementedError("FlashNeoX does not support quantization")
-
         tokenizer = AutoTokenizer.from_pretrained(
             model_id, revision=revision, padding_side="left", truncation_side="left"
         )

@@ -63,13 +60,13 @@ class FlashNeoXSharded(FlashNeoX):
         self.load_weights(
             model,
             filenames,
+            quantize=quantize,
             device=device,
             dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
-        model.post_load_weights()
-        self.model = model.eval().to(device)
+        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
 
         super(FlashCausalLM, self).__init__(
             tokenizer=tokenizer,

@@ -80,6 +77,7 @@ class FlashNeoXSharded(FlashNeoX):
     def load_weights(
         model,
         filenames: List[str],
+        quantize: bool,
         device: torch.device,
         dtype: torch.dtype,
         rank: int,

@@ -87,7 +85,9 @@ class FlashNeoXSharded(FlashNeoX):
     ):
         parameters = dict(model.named_parameters())
         for file in filenames:
-            with safe_open(file, framework="pt", device=str(device)) as f:
+            with safe_open(
+                file, framework="pt", device=str(device) if not quantize else "cpu"
+            ) as f:
                 for name in f.keys():
                     module_name, param_name = name.rsplit(".", 1)
                     module = model.get_submodule(module_name)

@@ -146,3 +146,4 @@ class FlashNeoXSharded(FlashNeoX):
                 module._parameters[param_name] = tensor
             else:
                 module._buffers[param_name] = tensor
+        model.post_load_weights(quantize)
server/text_generation_server/models/flash_santacoder.py

@@ -34,9 +34,6 @@ class FlashSantacoder(FlashCausalLM):
         else:
             raise NotImplementedError("FlashSantacoder is only available on GPU")
 
-        if quantize:
-            raise NotImplementedError("FlashSantacoder does not support quantization")
-
         tokenizer = AutoTokenizer.from_pretrained(
             model_id, revision=revision, padding_side="left", truncation_side="left"
         )

@@ -58,9 +55,14 @@ class FlashSantacoder(FlashCausalLM):
             model = FlashSantacoderForCausalLM(config)
 
         self.load_weights(
-            model, filenames, device, dtype, config.architectures[0].startswith("GPT2")
+            model,
+            filenames,
+            quantize,
+            device,
+            dtype,
+            config.architectures[0].startswith("GPT2"),
         )
-        self.model = model.eval()
+        self.model = model.eval().to(device)
 
         super(FlashCausalLM, self).__init__(
             tokenizer=tokenizer, device=device, decode_buffer=1

@@ -70,6 +72,7 @@ class FlashSantacoder(FlashCausalLM):
     def load_weights(
         model: FlashSantacoderForCausalLM,
         filenames: List[Path],
+        quantize: bool,
         device: torch.device,
         dtype: torch.dtype,
         transpose: bool,

@@ -77,7 +80,7 @@ class FlashSantacoder(FlashCausalLM):
         for filename in filenames:
             state_dict = torch.load(filename, map_location="cpu")
             for key, value in state_dict.items():
-                value = value.to(device).to(dtype)
+                value = value.to(device if not quantize else "cpu").to(dtype)
 
                 layer_name = ".".join(key.split(".")[:4])

@@ -152,7 +155,7 @@ class FlashSantacoder(FlashCausalLM):
             del value
 
         torch.cuda.empty_cache()
-        model.post_load_weights()
+        model.post_load_weights(quantize)
 
     def decode(self, generated_ids: List[int]) -> str:
         # Do not skip special tokens as they are used for custom parsing rules of the generated text

@@ -173,11 +176,6 @@ class FlashSantacoderSharded(FlashSantacoder):
         else:
             raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
 
-        if quantize:
-            raise NotImplementedError(
-                "FlashSantacoderSharded does not support quantization"
-            )
-
         tokenizer = AutoTokenizer.from_pretrained(
             model_id, revision=revision, padding_side="left", truncation_side="left"
         )

@@ -197,13 +195,14 @@ class FlashSantacoderSharded(FlashSantacoder):
         self.load_weights(
             model,
             filenames,
+            quantize=quantize,
             device=device,
             dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
             transpose=config.architectures[0].startswith("GPT2"),
         )
-        self.model = model.eval()
+        self.model = model.eval().to(device)
         torch.distributed.barrier(group=self.process_group)
 
         super(FlashCausalLM, self).__init__(
             tokenizer=tokenizer,

@@ -214,6 +213,7 @@ class FlashSantacoderSharded(FlashSantacoder):
     def load_weights(
         model,
         filenames: List[str],
+        quantize: bool,
         device: torch.device,
         dtype: torch.dtype,
         rank: int,

@@ -221,7 +221,9 @@ class FlashSantacoderSharded(FlashSantacoder):
         transpose: bool,
     ):
         for file in filenames:
-            with safe_open(file, framework="pt", device=str(device)) as f:
+            with safe_open(
+                file, framework="pt", device=str(device) if not quantize else "cpu"
+            ) as f:
                 for key in f.keys():
                     slice_ = f.get_slice(key)

@@ -363,4 +365,4 @@ class FlashSantacoderSharded(FlashSantacoder):
             else:
                 module._buffers[param_name] = tensor
         torch.cuda.empty_cache()
-        model.post_load_weights()
+        model.post_load_weights(quantize)