OpenDAS / text-generation-inference / Commits

Commit 47954b81 (unverified), authored Sep 27, 2023 by OlivierDehaene, committed via GitHub on Sep 27, 2023
Parent: b32e9ce9

feat: format code (#1070)

Changes: 28 files in the commit; this page (1 of 2) shows 8 changed files with 131 additions and 62 deletions (+131 −62).
Files on this page:

    server/text_generation_server/models/seq2seq_lm.py           +4  −2
    server/text_generation_server/server.py                      +17 −5
    server/text_generation_server/utils/awq/quantize/qmodule.py  +7  −5
    server/text_generation_server/utils/gptq/quantize.py         +4  −2
    server/text_generation_server/utils/layers.py                +76 −26
    server/text_generation_server/utils/peft.py                  +2  −4
    server/text_generation_server/utils/tokens.py                +1  −1
    server/text_generation_server/utils/weights.py               +20 −17
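Every hunk in this commit is a pure re-wrap of existing code: long calls are exploded one argument per line with a magic trailing comma, over-long conditions are wrapped in parentheses, and single-element tuples lose their stray space. This is characteristic output of `black`, though the commit message says only "format code", so the tool name is an assumption. A sketch reproducing one of the changes through black's Python API:

```python
import black  # assumption: the repo was formatted with black

src = "out_shape = x.shape[:-1] + (self.out_features, )\n"
# black normalizes the single-element tuple: no space before the paren.
print(black.format_str(src, mode=black.Mode()), end="")
# -> out_shape = x.shape[:-1] + (self.out_features,)
```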
server/text_generation_server/models/seq2seq_lm.py (view file @ 47954b81)

```diff
@@ -712,9 +712,11 @@ class Seq2SeqLM(Model):
         # Decode all tokens
         output_text, _, _ = self.decode_token(
             all_decoder_input_ids,
             prefix_offset=len(all_decoder_input_ids) - decoder_input_length - 1,
             read_offset=len(all_decoder_input_ids) - decoder_input_length,
-            skip_special_tokens=True
+            skip_special_tokens=True,
         )
         # Get seed
```
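The offsets passed here drive streaming-safe incremental detokenization: the window of recent tokens is decoded twice and only the string suffix is emitted, so partial UTF-8 sequences and BPE merge artifacts never leak into streamed output. A minimal standalone sketch of the idea (written as a hypothetical free function; the real `decode_token` is a method on the model base class):

```python
# Sketch of incremental detokenization, assuming a Hugging Face tokenizer.
def decode_token(tokenizer, all_input_ids, prefix_offset, read_offset,
                 skip_special_tokens=True):
    # Decode the previously emitted window and the extended window.
    prefix_text = tokenizer.decode(
        all_input_ids[prefix_offset:read_offset],
        skip_special_tokens=skip_special_tokens,
    )
    new_text = tokenizer.decode(
        all_input_ids[prefix_offset:], skip_special_tokens=skip_special_tokens
    )
    if len(new_text) > len(prefix_text) and not new_text.endswith("\ufffd"):
        # A complete character was produced: emit the difference, slide window.
        return new_text[len(prefix_text):], read_offset, len(all_input_ids)
    # Incomplete character (replacement char at the end): emit nothing yet.
    return "", prefix_offset, read_offset
```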
server/text_generation_server/server.py (view file @ 47954b81)

```diff
@@ -16,6 +16,7 @@ from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
 from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch

+
 class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
     def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
         self.cache = cache
@@ -26,7 +27,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         # Force inference mode for the lifetime of TextGenerationService
         self._inference_mode_raii_guard = torch._C._InferenceMode(True)
-

     async def Info(self, request, context):
         return self.model.info
@@ -55,9 +55,15 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

     async def Warmup(self, request, context):
-        if self.model.batch_type == IdeficsCausalLMBatch:  #Hack, i would rather use kwargs in the `from_pb` call
+        if (
+            self.model.batch_type == IdeficsCausalLMBatch
+        ):  # Hack, i would rather use kwargs in the `from_pb` call
             batch = self.model.batch_type.from_pb(
-                request.batch, self.model.tokenizer, self.model.processor, self.model.dtype, self.model.device
+                request.batch,
+                self.model.tokenizer,
+                self.model.processor,
+                self.model.dtype,
+                self.model.device,
             )
         else:
             batch = self.model.batch_type.from_pb(
@@ -70,9 +76,15 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             )

     async def Prefill(self, request, context):
-        if self.model.batch_type == IdeficsCausalLMBatch:  #Hack, i would rather use kwargs in the `from_pb` call
+        if (
+            self.model.batch_type == IdeficsCausalLMBatch
+        ):  # Hack, i would rather use kwargs in the `from_pb` call
             batch = self.model.batch_type.from_pb(
-                request.batch, self.model.tokenizer, self.model.processor, self.model.dtype, self.model.device
+                request.batch,
+                self.model.tokenizer,
+                self.model.processor,
+                self.model.dtype,
+                self.model.device,
             )
         else:
             batch = self.model.batch_type.from_pb(
```
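The `# Hack` comment above flags the positional special case for the Idefics batch type. A hypothetical sketch of the kwargs-based dispatch the comment wishes for, where the batch class itself declares whether it needs the multimodal processor (none of the names below exist in the repo; they only illustrate the alternative):

```python
def build_batch(model, pb_batch):
    # Hypothetical: batch classes expose a `needs_processor` class attribute,
    # so the server never has to compare against a concrete batch type.
    kwargs = {}
    if getattr(model.batch_type, "needs_processor", False):
        kwargs["processor"] = model.processor
    return model.batch_type.from_pb(
        pb_batch,
        tokenizer=model.tokenizer,
        dtype=model.dtype,
        device=model.device,
        **kwargs,
    )
```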
server/text_generation_server/utils/awq/quantize/qmodule.py (view file @ 47954b81)

```diff
@@ -42,7 +42,9 @@ class WQLinear(nn.Module):
     @torch.no_grad()
     def forward(self, x):
-        out_shape = x.shape[:-1] + (self.out_features, )
-        out = awq_inference_engine.gemm_forward_cuda(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8)
+        out_shape = x.shape[:-1] + (self.out_features,)
+        out = awq_inference_engine.gemm_forward_cuda(
+            x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8
+        )
         out = out + self.bias if self.bias is not None else out
         return out.reshape(out_shape)
```
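`forward` above uses the standard trick for kernels that only accept 2-D inputs: collapse every leading dimension into one batch axis, run the GEMM, then restore the original shape. A self-contained sketch with a plain matmul standing in for the AWQ CUDA kernel:

```python
import torch

def linear_any_rank(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Remember the leading dims, flatten, multiply, and reshape back.
    out_shape = x.shape[:-1] + (weight.shape[1],)
    out = x.reshape(-1, x.shape[-1]) @ weight  # the kernel sees a 2-D problem
    return out.reshape(out_shape)

y = linear_any_rank(torch.randn(2, 3, 8), torch.randn(8, 4))
assert y.shape == (2, 3, 4)
```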
server/text_generation_server/utils/gptq/quantize.py (view file @ 47954b81)

```diff
@@ -578,7 +578,9 @@ def get_c4_new(nsamples, seed, seqlen, model_id, trust_remote_code):
     return trainloader, valenc


-def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model_id="", trust_remote_code=False):
+def get_loaders(
+    name, nsamples=128, seed=0, seqlen=2048, model_id="", trust_remote_code=False
+):
     if "wikitext2" in name:
         return get_wikitext2(nsamples, seed, seqlen, model_id, trust_remote_code)
     if "ptb" in name:
@@ -927,7 +929,7 @@ def quantize(
         seed=seed,
         model_id=model_id,
         seqlen=model.seqlen,
-        trust_remote_code=trust_remote_code
+        trust_remote_code=trust_remote_code,
     )
     tick = time.time()
```
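For reference, `get_loaders` is the entry point the `quantize(...)` call in the second hunk goes through; the calibration dataset is picked by substring match on the name. A hypothetical call (the model id is chosen only for illustration):

```python
# Hypothetical usage sketch: fetch GPTQ calibration batches by dataset name.
trainloader, valenc = get_loaders(
    "wikitext2",
    nsamples=128,
    seed=0,
    seqlen=2048,
    model_id="gpt2",  # hypothetical; used to build the calibration tokenizer
    trust_remote_code=False,
)
```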
server/text_generation_server/utils/layers.py (view file @ 47954b81)

```diff
@@ -38,6 +38,7 @@ if os.getenv("DISABLE_EXLLAMA") == "True":
 elif CAN_EXLLAMA:
     try:
         from text_generation_server.utils.gptq.exllama import Ex4bitLinear
+
         HAS_EXLLAMA = True
     except ImportError:
         pass
@@ -47,6 +48,7 @@ from typing import Optional
 HAS_EETQ = False
 try:
     from EETQ import quant_weights, w8_a16_gemm
+
     HAS_EETQ = True
 except ImportError:
     pass
```
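Both hunks follow the same optional-dependency pattern used throughout this file: attempt the import, flip a capability flag on success, and swallow `ImportError` so the server still starts without the extra. A minimal self-contained sketch (the package name is a stand-in for EETQ or exllama):

```python
HAS_FANCY_KERNEL = False
try:
    import fancy_kernel  # noqa: F401 -- hypothetical optional backend

    HAS_FANCY_KERNEL = True
except ImportError:
    pass  # feature code later checks HAS_FANCY_KERNEL before using it
```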
```diff
@@ -74,12 +76,18 @@ def load_layer_norm_no_bias(cls, prefix, weights, eps):
     ln.bias = None
     return ln


 @classmethod
 def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, stride):
     weight = weights.get_tensor(f"{prefix}.weight")
     bias = weights.get_tensor(f"{prefix}.bias")
     with init_empty_weights():
-        conv2d = cls(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride)
+        conv2d = cls(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+        )
+
     conv2d.weight = nn.Parameter(weight)
     conv2d.bias = nn.Parameter(bias)
@@ -87,10 +95,17 @@ def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, st
 @classmethod
-def load_conv2d_no_bias(cls, prefix, weights, in_channels, out_channels, kernel_size, stride):
+def load_conv2d_no_bias(
+    cls, prefix, weights, in_channels, out_channels, kernel_size, stride
+):
     weight = weights.get_tensor(f"{prefix}.weight")
     with init_empty_weights():
-        conv2d = cls(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride)
+        conv2d = cls(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+        )

     conv2d.weight = nn.Parameter(weight)
     conv2d.bias = None
```
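The `with init_empty_weights():` block builds the module on the meta device, so `cls(...)` allocates nothing; real tensors read from the checkpoint are attached as parameters afterwards. A sketch of the same pattern, assuming `accelerate` is installed:

```python
import torch
from torch import nn
from accelerate import init_empty_weights

with init_empty_weights():
    # Constructed on the meta device: no memory is allocated here.
    conv2d = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, stride=1)

# Attach real tensors afterwards (random here; normally read from safetensors).
conv2d.weight = nn.Parameter(torch.randn(8, 3, 3, 3))
conv2d.bias = nn.Parameter(torch.zeros(8))
```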
```diff
@@ -215,7 +230,10 @@ class Linear4bit(nn.Module):
     def __init__(self, weight, bias, quant_type):
         super().__init__()
         self.weight = Params4bit(
-            weight.data, requires_grad=False, compress_statistics=True, quant_type=quant_type
+            weight.data,
+            requires_grad=False,
+            compress_statistics=True,
+            quant_type=quant_type,
         )
         self.compute_dtype = None
         self.weight.cuda(weight.device)
@@ -246,7 +264,10 @@ class Linear4bit(nn.Module):
 @lru_cache(1)
 def warn_deprecate_bnb():
-    logger.warning("Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performnce")
+    logger.warning(
+        "Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performnce"
+    )
```
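`@lru_cache(1)` on a zero-argument function is a compact warn-once idiom: the first call runs the body, every later call is a cache hit and does nothing. A self-contained sketch:

```python
from functools import lru_cache

@lru_cache(1)
def warn_once() -> None:  # stand-in for warn_deprecate_bnb
    print("deprecated code path used")

warn_once()  # prints
warn_once()  # cache hit: silent
```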
```diff
@@ -255,7 +276,9 @@ def get_linear(weight, bias, quantize):
         if HAS_EETQ:
             linear = EETQLinear(weight, bias)
         else:
-            raise ImportError("Please install EETQ from https://github.com/NetEase-FuXi/EETQ")
+            raise ImportError(
+                "Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
+            )
     elif quantize == "bitsandbytes":
         warn_deprecate_bnb()
         linear = Linear8bitLt(
@@ -305,7 +328,14 @@ def get_linear(weight, bias, quantize):
             raise NotImplementedError(
                 f"The passed weight is not `awq` compatible, loader needs to be updated."
             )
-        linear = WQLinear(w_bit=bits, group_size=groupsize, qweight=qweight, qzeros=qzeros, scales=scales, bias=bias is not None)
+        linear = WQLinear(
+            w_bit=bits,
+            group_size=groupsize,
+            qweight=qweight,
+            qzeros=qzeros,
+            scales=scales,
+            bias=bias is not None,
+        )
     else:
         raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
     return linear
```
@@ -392,9 +422,7 @@ class TensorParallelColumnLinear(SuperLayer):
...
@@ -392,9 +422,7 @@ class TensorParallelColumnLinear(SuperLayer):
@
classmethod
@
classmethod
def
load_qkv
(
cls
,
config
,
prefix
:
str
,
weights
,
bias
:
bool
):
def
load_qkv
(
cls
,
config
,
prefix
:
str
,
weights
,
bias
:
bool
):
"""Specific method when the QKV was joined after the fact"""
"""Specific method when the QKV was joined after the fact"""
weight
=
weights
.
get_weights_col_packed_qkv
(
weight
=
weights
.
get_weights_col_packed_qkv
(
prefix
,
quantize
=
config
.
quantize
)
prefix
,
quantize
=
config
.
quantize
)
if
bias
:
if
bias
:
raise
NotImplementedError
(
"packed_qkv only implemented for baichuan"
)
raise
NotImplementedError
(
"packed_qkv only implemented for baichuan"
)
else
:
else
:
...
```diff
@@ -530,14 +558,16 @@ try:
     def _create_inv_freq(dim, base, device):
         inv_freq = 1.0 / (
-            base
-            ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
+            base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
         )
         return inv_freq

     def _get_rope_config(config):
         if os.getenv("ROPE_SCALING", None) is not None:
-            rope_scaling = {"type": os.environ["ROPE_SCALING"], "factor": float(os.environ["ROPE_FACTOR"])}
+            rope_scaling = {
+                "type": os.environ["ROPE_SCALING"],
+                "factor": float(os.environ["ROPE_FACTOR"]),
+            }
             return rope_scaling
         return getattr(config, "rope_scaling", None)
```
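`_create_inv_freq` computes the standard rotary-embedding frequencies, `inv_freq[i] = 1 / base**(2i / dim)`. A small worked example:

```python
import torch

dim, base = 8, 10000.0
inv_freq = 1.0 / (
    base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)
)
print(inv_freq)  # tensor([1.0000e+00, 1.0000e-01, 1.0000e-02, 1.0000e-03])
```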
```diff
@@ -563,9 +593,17 @@ try:
                 if rope_scaling["type"] == "linear":
                     pass
                 elif rope_scaling["type"] == "dynamic":
-                    return DynamicPositionRotaryEmbedding(dim=dim, max_position_embeddings=config.max_position_embeddings, base=base, device=inv_freq.device, scaling_factor=scaling_factor)
+                    return DynamicPositionRotaryEmbedding(
+                        dim=dim,
+                        max_position_embeddings=config.max_position_embeddings,
+                        base=base,
+                        device=inv_freq.device,
+                        scaling_factor=scaling_factor,
+                    )
                 else:
-                    raise NotImplementedError(f"rope scaling type {rope_scaling['type']} is not implemented or invalid")
+                    raise NotImplementedError(
+                        f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
+                    )
             return cls(inv_freq, scaling_factor)

         @classmethod
@@ -583,9 +621,17 @@ try:
                 if rope_scaling["type"] == "linear":
                     pass
                 elif rope_scaling["type"] == "dynamic":
-                    return DynamicPositionRotaryEmbedding(dim=2 * inv_freq.shape[0], max_position_embeddings=config.max_position_embeddings, base=10000.0, device=inv_freq.device, scaling_factor=scaling_factor)
+                    return DynamicPositionRotaryEmbedding(
+                        dim=2 * inv_freq.shape[0],
+                        max_position_embeddings=config.max_position_embeddings,
+                        base=10000.0,
+                        device=inv_freq.device,
+                        scaling_factor=scaling_factor,
+                    )
                 else:
-                    raise NotImplementedError(f"rope scaling type {rope_scaling['type']} is not implemented or invalid")
+                    raise NotImplementedError(
+                        f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
+                    )
             return cls(inv_freq, scaling_factor)

         def _update_cos_sin_cache(self, dtype, device, seqlen):
```
```diff
@@ -645,8 +691,13 @@ try:
                 or self._cos_cached.dtype != dtype
             ):
                 if seqlen > self.max_position_embeddings:
-                    newbase = self.base * ((self.scaling_factor * seqlen / self.max_position_embeddings) - (self.scaling_factor - 1)) ** (self.dim / (self.dim - 2))
-                    self.inv_freq = _create_inv_freq(self.dim, newbase, self.inv_freq.device)
+                    newbase = self.base * (
+                        (self.scaling_factor * seqlen / self.max_position_embeddings)
+                        - (self.scaling_factor - 1)
+                    ) ** (self.dim / (self.dim - 2))
+                    self.inv_freq = _create_inv_freq(
+                        self.dim, newbase, self.inv_freq.device
+                    )
                 self._seq_len_cached = seqlen
                 t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                 # Don't do einsum, it converts fp32 to fp16
@@ -656,6 +707,5 @@ try:
             self._cos_cached = torch.cos(freqs).to(dtype)
             self._sin_cached = torch.sin(freqs).to(dtype)
-

 except ImportError:
     pass
```
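The cache-update hunk is where dynamic rope scaling ("dynamic NTK") takes effect: once the sequence outgrows `max_position_embeddings`, the base is re-derived so the frequency spectrum stretches over the longer context. The arithmetic in isolation:

```python
def dynamic_ntk_base(base, dim, seqlen, max_pos, scaling_factor):
    # Same formula as the `newbase` expression in the hunk above.
    return base * (
        (scaling_factor * seqlen / max_pos) - (scaling_factor - 1)
    ) ** (dim / (dim - 2))

# e.g. an 8192-token sequence on a model trained for 4096 positions:
print(dynamic_ntk_base(10000.0, 128, 8192, 4096, 2.0))  # ~3.05e4; a larger
# base lowers every frequency, spreading the rotations over more positions
```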
server/text_generation_server/utils/peft.py (view file @ 47954b81)

```diff
@@ -6,6 +6,7 @@ import torch
 from transformers import AutoTokenizer
 from peft import AutoPeftModelForCausalLM, AutoPeftModelForSeq2SeqLM

+
 def download_and_unload_peft(model_id, revision, trust_remote_code):
     torch_dtype = torch.float16
@@ -41,6 +42,3 @@ def download_and_unload_peft(model_id, revision, trust_remote_code):
     model.save_pretrained(cache_dir, safe_serialization=True)
     model.config.save_pretrained(cache_dir)
     tokenizer.save_pretrained(cache_dir)
```
server/text_generation_server/utils/tokens.py (view file @ 47954b81): +1 −1; the hunk is not expanded on this page.
server/text_generation_server/utils/weights.py (view file @ 47954b81)

```diff
@@ -62,7 +62,7 @@ class Weights:
     def get_shape(self, tensor_name: str):
         return self._get_slice(tensor_name).get_shape()

     def get_tensor(self, tensor_name: str, to_device=True):
         filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
         tensor = f.get_tensor(tensor_name)
@@ -110,7 +110,6 @@ class Weights:
         ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
         return self.get_partial_sharded(tensor_name, dim)

     def _get_qweight(self, name: str):
         slice_ = self._get_slice(name)
         total_size = slice_.get_shape()[1]
@@ -119,14 +118,16 @@ class Weights:
         world_size = self.process_group.size()
         rank = self.process_group.rank()
-        assert single_size % world_size == 0, f"Prepacked quantized qkv cannot be sharded across {world_size} shards"
+        assert (
+            single_size % world_size == 0
+        ), f"Prepacked quantized qkv cannot be sharded across {world_size} shards"
         block_size = single_size // world_size
         start = rank * block_size
         stop = (rank + 1) * block_size
         q = slice_[:, start:stop]
         k = slice_[:, start + single_size : stop + single_size]
         v = slice_[:, start + 2 * single_size : stop + 2 * single_size]
         weight = torch.cat([q, k, v], dim=1)
         weight = weight.to(device=self.device)
         return weight
@@ -161,14 +162,16 @@ class Weights:
         world_size = self.process_group.size()
         rank = self.process_group.rank()
-        assert single_size % world_size == 0, f"Prepacked qkv cannot be sharded across {world_size} shards"
+        assert (
+            single_size % world_size == 0
+        ), f"Prepacked qkv cannot be sharded across {world_size} shards"
         block_size = single_size // world_size
         start = rank * block_size
         stop = (rank + 1) * block_size
         q = slice_[start:stop]
         k = slice_[start + single_size : stop + single_size]
         v = slice_[start + 2 * single_size : stop + 2 * single_size]
         weight = torch.cat([q, k, v], dim=0)
         weight = weight.to(device=self.device)
         weight = weight.to(dtype=self.dtype)
         return weight
```
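Both weights.py hunks shard a prepacked `[Q | K | V]` tensor across tensor-parallel ranks: each of the three blocks spans `single_size` columns (or rows), and every rank takes the same contiguous slice from each block before re-concatenating. The index arithmetic, worked for a small example:

```python
# 3 packed blocks of single_size columns each, sharded over 4 ranks.
single_size, world_size, rank = 12, 4, 1
block_size = single_size // world_size  # 3 columns per rank per block
start, stop = rank * block_size, (rank + 1) * block_size
print(start, stop)                                       # 3 6   -> Q slice
print(start + single_size, stop + single_size)           # 15 18 -> K slice
print(start + 2 * single_size, stop + 2 * single_size)   # 27 30 -> V slice
```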