Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
text-generation-inference
Commits
db4efbf4
Unverified
Commit
db4efbf4
authored
Jul 12, 2023
by
Nicolas Patry
Committed by
GitHub
Jul 12, 2023
Browse files
fix(server): T5 weights names. (#582)
Fixes #541
parent
f063ebde
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
11 additions
and
7 deletions
+11
-7
server/text_generation_server/models/custom_modeling/t5_modeling.py
...t_generation_server/models/custom_modeling/t5_modeling.py
+1
-6
server/text_generation_server/models/t5.py
server/text_generation_server/models/t5.py
+10
-1
No files found.
server/text_generation_server/models/custom_modeling/t5_modeling.py
View file @
db4efbf4
...
@@ -1006,12 +1006,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
...
@@ -1006,12 +1006,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
super
().
__init__
(
config
)
super
().
__init__
(
config
)
self
.
model_dim
=
config
.
d_model
self
.
model_dim
=
config
.
d_model
try
:
self
.
shared
=
TensorParallelEmbedding
(
prefix
=
"shared"
,
weights
=
weights
)
self
.
shared
=
TensorParallelEmbedding
(
prefix
=
"shared"
,
weights
=
weights
)
except
RuntimeError
:
self
.
shared
=
TensorParallelEmbedding
(
prefix
=
"encoder.embed_tokens"
,
weights
=
weights
)
encoder_config
=
copy
.
deepcopy
(
config
)
encoder_config
=
copy
.
deepcopy
(
config
)
encoder_config
.
is_decoder
=
False
encoder_config
.
is_decoder
=
False
...
...
server/text_generation_server/models/t5.py
View file @
db4efbf4
...
@@ -55,7 +55,16 @@ class T5Sharded(Seq2SeqLM):
...
@@ -55,7 +55,16 @@ class T5Sharded(Seq2SeqLM):
torch
.
distributed
.
barrier
(
group
=
self
.
process_group
)
torch
.
distributed
.
barrier
(
group
=
self
.
process_group
)
filenames
=
weight_files
(
model_id
,
revision
=
revision
,
extension
=
".safetensors"
)
filenames
=
weight_files
(
model_id
,
revision
=
revision
,
extension
=
".safetensors"
)
weights
=
Weights
(
weights
=
Weights
(
filenames
,
device
=
device
,
dtype
=
dtype
,
process_group
=
self
.
process_group
filenames
,
device
=
device
,
dtype
=
dtype
,
process_group
=
self
.
process_group
,
aliases
=
{
"shared.weight"
:
[
"encoder.embed_tokens.weight"
,
"decoder.embed_tokens.weight"
,
]
},
)
)
model
=
T5ForConditionalGeneration
(
config
,
weights
)
model
=
T5ForConditionalGeneration
(
config
,
weights
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment