OpenDAS / text-generation-inference · Commit 95d35469 (Unverified)

Authored Jun 01, 2023 by OlivierDehaene, committed by GitHub on Jun 01, 2023
Parent: c0928e6f

feat(server): load santacoder/starcoder models with safetensors (#393)

Fix #366
Changes: 2 changed files, with 91 additions and 91 deletions.

  launcher/src/main.rs                                        +2  -14
  server/text_generation_server/models/flash_santacoder.py   +89  -77
launcher/src/main.rs
@@ -546,11 +546,7 @@ enum LauncherError {
     WebserverCannotStart,
 }
 
-fn download_convert_model(
-    args: &Args,
-    auto_convert: bool,
-    running: Arc<AtomicBool>,
-) -> Result<(), LauncherError> {
+fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
     let mut download_argv = vec![
         "text-generation-server".to_string(),
         "download-weights".to_string(),
@@ -562,11 +558,6 @@ fn download_convert_model(
         "--json-output".to_string(),
     ];
 
-    // Auto convert weights to safetensors
-    if auto_convert {
-        download_argv.push("--auto-convert".to_string());
-    }
-
     // Model optional revision
     if let Some(revision) = &args.revision {
         download_argv.push("--revision".to_string());
@@ -932,11 +923,8 @@ fn main() -> Result<(), LauncherError> {
     })
     .expect("Error setting Ctrl-C handler");
 
-    // auto_convert is only needed for sharded models as we do not require safetensors in
-    // single shard mode
-    let auto_convert = num_shard > 1;
-
     // Download and convert model weights
-    download_convert_model(&args, auto_convert, running.clone())?;
+    download_convert_model(&args, running.clone())?;
 
     // Shared shutdown bool
     let shutdown = Arc::new(Mutex::new(false));
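Note: with the `auto_convert` gate gone, the launcher runs the same download command for single-shard and sharded models alike, so safetensors weights become the norm rather than a sharding-only requirement. A rough Python equivalent of the argv the launcher assembles, a sketch using only the flags visible in this diff (the positional `model_id` and anything else the real launcher appends are assumptions):

```python
import subprocess
from typing import Optional


def download_convert_model(model_id: str, revision: Optional[str] = None) -> None:
    # Mirrors download_argv after this commit: the conditional
    # `--auto-convert` push has been removed, the rest is unchanged.
    argv = [
        "text-generation-server",
        "download-weights",
        model_id,  # assumption: the real launcher also passes the model id
        "--json-output",
    ]
    if revision is not None:
        # Model optional revision, as in the Rust code above.
        argv.extend(["--revision", revision])
    subprocess.run(argv, check=True)
```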
server/text_generation_server/models/flash_santacoder.py
@@ -54,12 +54,7 @@ class FlashSantacoder(FlashCausalLM):
         )
 
         # We do not use from_pretrained as we modified the model internal module layout
-        try:
-            filenames = weight_files(model_id, revision, ".bin")
-        # Local files not found
-        except LocalEntryNotFoundError:
-            hub_files = weight_hub_files(model_id, revision, ".bin")
-            filenames = download_weights(hub_files, model_id, revision)
+        filenames = weight_files(model_id, revision, ".safetensors")
 
         with init_empty_weights():
             model = FlashSantacoderForCausalLM(config)
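Note: the constructor now expects local `.safetensors` shards to already exist, so the in-constructor fallback (fetching `.bin` files from the Hub via `weight_hub_files`/`download_weights`) disappears; the launcher's download step is responsible for producing safetensors beforehand. To make that contract concrete, here is a hypothetical stand-in for the helper (the real `weight_files` lives in `text_generation_server.utils` and differs):

```python
from pathlib import Path
from typing import List


def weight_files_sketch(model_dir: str, extension: str = ".safetensors") -> List[Path]:
    # Hypothetical: list weight shards with the given extension and fail
    # loudly when none exist, instead of silently falling back to .bin.
    files = sorted(Path(model_dir).glob(f"*{extension}"))
    if not files:
        raise FileNotFoundError(
            f"no {extension} weights found in {model_dir}; "
            "run `text-generation-server download-weights` first"
        )
    return files
```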
@@ -91,85 +86,100 @@ class FlashSantacoder(FlashCausalLM):
         transpose: bool,
     ):
         for filename in filenames:
-            state_dict = torch.load(filename, map_location="cpu")
-            for key, value in state_dict.items():
-                value = value.to(device if quantize is None else "cpu").to(dtype)
-
-                layer_name = ".".join(key.split(".")[:4])
-
-                # Fused qkv
-                if "q_attn.weight" in key or "kv_attn.weight" in key:
-                    final_key = layer_name + ".c_attn.weight"
-                elif "q_attn.bias" in key or "kv_attn.bias" in key:
-                    final_key = layer_name + ".c_attn.bias"
-                else:
-                    final_key = key
-
-                module_name, param_name = final_key.rsplit(".", 1)
-                module = model.get_submodule(module_name)
-
-                try:
-                    current_parameter_tensor = module._parameters[param_name]
-                except KeyError:
-                    current_parameter_tensor = None
-
-                if current_parameter_tensor is not None:
-                    if transpose and (
-                        "c_fc.weight" in key
-                        or "c_proj.weight" in key
-                        or "q_attn.weight" in key
-                        or "kv_attn.weight" in key
-                        or "c_attn.weight" in key
-                    ):
-                        # Tranpose as we use nn.Linear instead of Conv1D
-                        value = value.T
-
-                    if current_parameter_tensor.device == torch.device("meta"):
-                        # Init qkv
-                        if "c_attn.weight" in final_key:
-                            module._parameters[param_name] = value.new_empty(
-                                (
-                                    model.transformer.head_size
-                                    * (model.transformer.num_heads + 2),
-                                    value.shape[1],
-                                )
-                            )
-                        elif "c_attn.bias" in final_key:
-                            module._parameters[param_name] = value.new_empty(
-                                (
-                                    model.transformer.head_size
-                                    * (model.transformer.num_heads + 2)
-                                )
-                            )
-
-                    # Copy to correct slice
-                    if "q_attn.weight" in key:
-                        module._parameters[param_name][: value.shape[0]] = value
-                    elif "q_attn.bias" in key:
-                        module._parameters[param_name][: value.shape[0]] = value
-                    elif "kv_attn.weight" in key:
-                        module._parameters[param_name][
-                            model.transformer.head_size * model.transformer.num_heads :
-                        ] = value
-                    elif "kv_attn.bias" in key:
-                        module._parameters[param_name][
-                            model.transformer.head_size * model.transformer.num_heads :
-                        ] = value
-                    else:
-                        if current_parameter_tensor.shape != value.shape:
-                            raise ValueError(
-                                f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
-                            )
-                        module._parameters[param_name] = value
-                else:
-                    module._buffers[param_name] = value
-
-                del value
+            with safe_open(
+                filename, framework="pt", device=str(device) if quantize is None else "cpu"
+            ) as f:
+                for key in f.keys():
+                    value = f.get_tensor(key)
+                    value = value.to(device if quantize is None else "cpu").to(dtype)
+
+                    layer_name = ".".join(key.split(".")[:4])
+
+                    # Fused qkv
+                    if "q_attn.weight" in key or "kv_attn.weight" in key:
+                        final_key = layer_name + ".c_attn.weight"
+                    elif "q_attn.bias" in key or "kv_attn.bias" in key:
+                        final_key = layer_name + ".c_attn.bias"
+                    else:
+                        final_key = key
+
+                    module_name, param_name = final_key.rsplit(".", 1)
+                    module = model.get_submodule(module_name)
+
+                    try:
+                        current_parameter_tensor = module._parameters[param_name]
+                    except KeyError:
+                        current_parameter_tensor = None
+
+                    if current_parameter_tensor is not None:
+                        if transpose and (
+                            "c_fc.weight" in key
+                            or "c_proj.weight" in key
+                            or "q_attn.weight" in key
+                            or "kv_attn.weight" in key
+                            or "c_attn.weight" in key
+                        ):
+                            # Tranpose as we use nn.Linear instead of Conv1D
+                            value = value.T
+
+                        if current_parameter_tensor.device == torch.device("meta"):
+                            # Init qkv
+                            if "c_attn.weight" in final_key:
+                                module._parameters[param_name] = value.new_empty(
+                                    (
+                                        model.transformer.head_size
+                                        * (model.transformer.num_heads + 2),
+                                        value.shape[1],
+                                    )
+                                )
+                            elif "c_attn.bias" in final_key:
+                                module._parameters[param_name] = value.new_empty(
+                                    (
+                                        model.transformer.head_size
+                                        * (model.transformer.num_heads + 2)
+                                    )
+                                )
+
+                        # Copy to correct slice
+                        if "q_attn.weight" in key:
+                            module._parameters[param_name][: value.shape[0]] = value
+                        elif "q_attn.bias" in key:
+                            module._parameters[param_name][: value.shape[0]] = value
+                        elif "kv_attn.weight" in key:
+                            module._parameters[param_name][
+                                model.transformer.head_size * model.transformer.num_heads :
+                            ] = value
+                        elif "kv_attn.bias" in key:
+                            module._parameters[param_name][
+                                model.transformer.head_size * model.transformer.num_heads :
+                            ] = value
+                        else:
+                            if current_parameter_tensor.shape != value.shape:
+                                raise ValueError(
+                                    f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
+                                )
+                            module._parameters[param_name] = value
+                    else:
+                        module._buffers[param_name] = value
+
+                    del value
+
+        if model.lm_head.weight.device == torch.device("meta"):
+            model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
 
         torch.cuda.empty_cache()
         model.post_load_weights(quantize)
 
+        uninitialized_parameters = []
+        for n, p in model.named_parameters():
+            if p.data.device == torch.device("meta"):
+                uninitialized_parameters.append(n)
+        if uninitialized_parameters:
+            raise RuntimeError(
+                f"found uninitialized parameters in model : {uninitialized_parameters}"
+            )
+
     def decode(self, generated_ids: List[int]) -> str:
         # Do not skip special tokens as they are used for custom parsing rules of the generated text
         return self.tokenizer.decode(
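Note: the core of this hunk is the swap from `torch.load`, which deserializes an entire `.bin` state dict into host memory at once, to safetensors' `safe_open`, which memory-maps the file and materializes one tensor per `get_tensor` call, directly on the target device. A self-contained sketch of that pattern plus the fused-qkv slice copy used above, with toy shapes standing in for `model.transformer.head_size`/`num_heads` (multi-query layout: `num_heads` query heads plus one shared key head and one shared value head):

```python
import torch
from safetensors import safe_open
from safetensors.torch import save_file

# Toy dimensions so the example runs end to end.
head_size, num_heads, hidden = 16, 4, 64

# Write a tiny .safetensors file to read back.
save_file(
    {
        "h.0.attn.q_attn.weight": torch.randn(head_size * num_heads, hidden),
        "h.0.attn.kv_attn.weight": torch.randn(head_size * 2, hidden),
    },
    "tiny.safetensors",
)

# Lazy, per-tensor loading: nothing is copied until get_tensor() is called.
with safe_open("tiny.safetensors", framework="pt", device="cpu") as f:
    q = f.get_tensor("h.0.attn.q_attn.weight")
    kv = f.get_tensor("h.0.attn.kv_attn.weight")

# Fuse into one c_attn parameter as the loader does: q fills the leading
# rows, kv fills the tail starting at head_size * num_heads.
c_attn = q.new_empty((head_size * (num_heads + 2), hidden))
c_attn[: q.shape[0]] = q
c_attn[head_size * num_heads :] = kv
assert c_attn.shape == (96, 64)
```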
@@ -389,6 +399,8 @@ class FlashSantacoderSharded(FlashSantacoder):
             else:
                 module._buffers[param_name] = tensor
 
-        model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
+        if model.lm_head.weight.device == torch.device("meta"):
+            model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
+
         torch.cuda.empty_cache()
         model.post_load_weights(quantize)
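Note: the sharded path now ties `lm_head` to the input embedding only when the head was never materialized, i.e. when it still sits on the `meta` device left behind by accelerate's `init_empty_weights`; the guard avoids clobbering a head that a checkpoint did provide. A minimal sketch of the pattern on a toy module (not the TGI model):

```python
import torch
from accelerate import init_empty_weights


class TinyLM(torch.nn.Module):
    def __init__(self, vocab: int = 10, dim: int = 4):
        super().__init__()
        self.wte = torch.nn.Embedding(vocab, dim)
        self.lm_head = torch.nn.Linear(dim, vocab, bias=False)


with init_empty_weights():
    model = TinyLM()  # every parameter starts on the meta device

# Pretend the checkpoint provided the embedding but no lm_head.
model.wte.weight = torch.nn.Parameter(torch.randn(10, 4))

# Tie only if lm_head was never loaded, mirroring the guard above.
if model.lm_head.weight.device == torch.device("meta"):
    model.lm_head.weight = torch.nn.Parameter(model.wte.weight)

# The two parameters now share storage.
assert model.lm_head.weight.data_ptr() == model.wte.weight.data_ptr()
```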