Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
text-generation-inference
Commits
74d3ce10
Unverified
Commit
74d3ce10
authored
Sep 24, 2024
by
Nicolas Patry
Committed by
GitHub
Sep 24, 2024
Browse files
Micro cleanup. (#2555)
parent
d31a6f75
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
2 additions
and
43 deletions
+2
-43
server/text_generation_server/utils/adapter.py
server/text_generation_server/utils/adapter.py
+2
-4
server/text_generation_server/utils/hub.py
server/text_generation_server/utils/hub.py
+0
-39
No files found.
server/text_generation_server/utils/adapter.py
View file @
74d3ce10
...
@@ -75,7 +75,6 @@ def load_and_merge_adapters(
...
@@ -75,7 +75,6 @@ def load_and_merge_adapters(
weight_names
:
Tuple
[
str
],
weight_names
:
Tuple
[
str
],
trust_remote_code
:
bool
=
False
,
trust_remote_code
:
bool
=
False
,
)
->
Tuple
[
"ModuleMap"
,
"AdapterConfig"
,
Set
[
str
],
PreTrainedTokenizer
]:
)
->
Tuple
[
"ModuleMap"
,
"AdapterConfig"
,
Set
[
str
],
PreTrainedTokenizer
]:
if
len
(
adapter_parameters
.
adapter_info
)
==
1
:
if
len
(
adapter_parameters
.
adapter_info
)
==
1
:
adapter
=
next
(
iter
(
adapter_parameters
.
adapter_info
))
adapter
=
next
(
iter
(
adapter_parameters
.
adapter_info
))
return
load_module_map
(
return
load_module_map
(
...
@@ -191,16 +190,15 @@ def load_module_map(
...
@@ -191,16 +190,15 @@ def load_module_map(
weight_names
:
Tuple
[
str
],
weight_names
:
Tuple
[
str
],
trust_remote_code
:
bool
=
False
,
trust_remote_code
:
bool
=
False
,
)
->
Tuple
[
"ModuleMap"
,
"AdapterConfig"
,
Set
[
str
],
PreTrainedTokenizer
]:
)
->
Tuple
[
"ModuleMap"
,
"AdapterConfig"
,
Set
[
str
],
PreTrainedTokenizer
]:
adapter_config
=
LoraConfig
.
load
(
adapter_path
or
adapter_id
,
None
)
adapter_config
=
LoraConfig
.
load
(
adapter_path
or
adapter_id
,
None
)
if
not
adapter_path
and
adapter_config
.
base_model_name_or_path
!=
model_id
:
if
not
adapter_path
and
adapter_config
.
base_model_name_or_path
!=
model_id
:
check_architectures
(
model_id
,
adapter_id
,
adapter_config
,
trust_remote_code
)
check_architectures
(
model_id
,
adapter_id
,
adapter_config
,
trust_remote_code
)
adapter_filenames
=
(
adapter_filenames
=
(
hub
.
_
adapter_
weight_files_from_dir
(
adapter_path
,
extension
=
".safetensors"
)
hub
.
_weight_files_from_dir
(
adapter_path
,
extension
=
".safetensors"
)
if
adapter_path
if
adapter_path
else
hub
.
_cached_
adapter_
weight_files
(
else
hub
.
_cached_weight_files
(
adapter_id
,
revision
=
revision
,
extension
=
".safetensors"
adapter_id
,
revision
=
revision
,
extension
=
".safetensors"
)
)
)
)
...
...
server/text_generation_server/utils/hub.py
View file @
74d3ce10
...
@@ -18,17 +18,6 @@ WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
...
@@ -18,17 +18,6 @@ WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
HF_HUB_OFFLINE
=
os
.
environ
.
get
(
"HF_HUB_OFFLINE"
,
"0"
).
lower
()
in
[
"true"
,
"1"
,
"yes"
]
HF_HUB_OFFLINE
=
os
.
environ
.
get
(
"HF_HUB_OFFLINE"
,
"0"
).
lower
()
in
[
"true"
,
"1"
,
"yes"
]
def
_cached_adapter_weight_files
(
adapter_id
:
str
,
revision
:
Optional
[
str
],
extension
:
str
)
->
List
[
str
]:
"""Guess weight files from the cached revision snapshot directory"""
d
=
_get_cached_revision_directory
(
adapter_id
,
revision
)
if
not
d
:
return
[]
filenames
=
_adapter_weight_files_from_dir
(
d
,
extension
)
return
filenames
def
_cached_weight_files
(
def
_cached_weight_files
(
model_id
:
str
,
revision
:
Optional
[
str
],
extension
:
str
model_id
:
str
,
revision
:
Optional
[
str
],
extension
:
str
)
->
List
[
str
]:
)
->
List
[
str
]:
...
@@ -65,39 +54,11 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]:
...
@@ -65,39 +54,11 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]:
if
f
.
endswith
(
extension
)
if
f
.
endswith
(
extension
)
and
"arguments"
not
in
f
and
"arguments"
not
in
f
and
"args"
not
in
f
and
"args"
not
in
f
and
"adapter"
not
in
f
and
"training"
not
in
f
and
"training"
not
in
f
]
]
return
filenames
return
filenames
def
_adapter_weight_files_from_dir
(
d
:
Path
,
extension
:
str
)
->
List
[
str
]:
# os.walk: do not iterate, just scan for depth 1, not recursively
# see _weight_files_from_dir, that's also what is done there
root
,
_
,
files
=
next
(
os
.
walk
(
str
(
d
)))
filenames
=
[
os
.
path
.
join
(
root
,
f
)
for
f
in
files
if
f
.
endswith
(
extension
)
and
"arguments"
not
in
f
and
"args"
not
in
f
and
"training"
not
in
f
]
return
filenames
def
_adapter_config_files_from_dir
(
d
:
Path
)
->
List
[
str
]:
# os.walk: do not iterate, just scan for depth 1, not recursively
# see _weight_files_from_dir, that's also what is done there
root
,
_
,
files
=
next
(
os
.
walk
(
str
(
d
)))
filenames
=
[
os
.
path
.
join
(
root
,
f
)
for
f
in
files
if
f
.
endswith
(
".json"
)
and
"arguments"
not
in
f
and
"args"
not
in
f
]
return
filenames
def
_get_cached_revision_directory
(
def
_get_cached_revision_directory
(
model_id
:
str
,
revision
:
Optional
[
str
]
model_id
:
str
,
revision
:
Optional
[
str
]
)
->
Optional
[
Path
]:
)
->
Optional
[
Path
]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment