Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9ba093b4
Unverified
Commit
9ba093b4
authored
Jun 05, 2024
by
Cyrus Leung
Committed by
GitHub
Jun 04, 2024
Browse files
[CI/Build] Simplify model loading for `HfRunner` (#5251)
parent
27208be6
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
18 additions
and
13 deletions
+18
-13
tests/conftest.py
tests/conftest.py
+16
-11
tests/models/test_embedding.py
tests/models/test_embedding.py
+1
-1
tests/models/test_llava.py
tests/models/test_llava.py
+1
-1
No files found.
tests/conftest.py
View file @
9ba093b4
import
contextlib
import
contextlib
import
gc
import
gc
import
os
import
os
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
TypeVar
import
pytest
import
pytest
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
PIL
import
Image
from
PIL
import
Image
from
transformers
import
(
AutoModelForCausalLM
,
Auto
Processor
,
AutoTokenizer
,
from
transformers
import
(
AutoModelForCausalLM
,
Auto
ModelForVision2Seq
,
LlavaConfig
,
LlavaForConditionalGeneration
)
AutoProcessor
,
AutoTokenizer
,
BatchEncoding
)
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
from
vllm.config
import
TokenizerPoolConfig
,
VisionLanguageConfig
from
vllm.config
import
TokenizerPoolConfig
,
VisionLanguageConfig
...
@@ -144,16 +145,12 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
...
@@ -144,16 +145,12 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
"float"
:
torch
.
float
,
"float"
:
torch
.
float
,
}
}
AutoModelForCausalLM
.
register
(
LlavaConfig
,
LlavaForConditionalGeneration
)
_T
=
TypeVar
(
"_T"
,
nn
.
Module
,
torch
.
Tensor
,
BatchEncoding
)
_EMBEDDING_MODELS
=
[
"intfloat/e5-mistral-7b-instruct"
,
]
class
HfRunner
:
class
HfRunner
:
def
wrap_device
(
self
,
input
:
any
)
:
def
wrap_device
(
self
,
input
:
_T
)
->
_T
:
if
not
is_cpu
():
if
not
is_cpu
():
return
input
.
to
(
"cuda"
)
return
input
.
to
(
"cuda"
)
else
:
else
:
...
@@ -163,13 +160,16 @@ class HfRunner:
...
@@ -163,13 +160,16 @@ class HfRunner:
self
,
self
,
model_name
:
str
,
model_name
:
str
,
dtype
:
str
=
"half"
,
dtype
:
str
=
"half"
,
*
,
is_embedding_model
:
bool
=
False
,
is_vision_model
:
bool
=
False
,
)
->
None
:
)
->
None
:
assert
dtype
in
_STR_DTYPE_TO_TORCH_DTYPE
assert
dtype
in
_STR_DTYPE_TO_TORCH_DTYPE
torch_dtype
=
_STR_DTYPE_TO_TORCH_DTYPE
[
dtype
]
torch_dtype
=
_STR_DTYPE_TO_TORCH_DTYPE
[
dtype
]
self
.
model_name
=
model_name
self
.
model_name
=
model_name
if
model_name
in
_EMBEDDING_MODELS
:
if
is_embedding_model
:
# Lazy init required for AMD CI
# Lazy init required for AMD CI
from
sentence_transformers
import
SentenceTransformer
from
sentence_transformers
import
SentenceTransformer
self
.
model
=
self
.
wrap_device
(
self
.
model
=
self
.
wrap_device
(
...
@@ -178,8 +178,13 @@ class HfRunner:
...
@@ -178,8 +178,13 @@ class HfRunner:
device
=
"cpu"
,
device
=
"cpu"
,
).
to
(
dtype
=
torch_dtype
))
).
to
(
dtype
=
torch_dtype
))
else
:
else
:
if
is_vision_model
:
auto_cls
=
AutoModelForVision2Seq
else
:
auto_cls
=
AutoModelForCausalLM
self
.
model
=
self
.
wrap_device
(
self
.
model
=
self
.
wrap_device
(
A
uto
ModelForCausalLM
.
from_pretrained
(
a
uto
_cls
.
from_pretrained
(
model_name
,
model_name
,
torch_dtype
=
torch_dtype
,
torch_dtype
=
torch_dtype
,
trust_remote_code
=
True
,
trust_remote_code
=
True
,
...
...
tests/models/test_embedding.py
View file @
9ba093b4
...
@@ -28,7 +28,7 @@ def test_models(
...
@@ -28,7 +28,7 @@ def test_models(
model
:
str
,
model
:
str
,
dtype
:
str
,
dtype
:
str
,
)
->
None
:
)
->
None
:
hf_model
=
hf_runner
(
model
,
dtype
=
dtype
)
hf_model
=
hf_runner
(
model
,
dtype
=
dtype
,
is_embedding_model
=
True
)
hf_outputs
=
hf_model
.
encode
(
example_prompts
)
hf_outputs
=
hf_model
.
encode
(
example_prompts
)
del
hf_model
del
hf_model
...
...
tests/models/test_llava.py
View file @
9ba093b4
...
@@ -94,7 +94,7 @@ def test_models(hf_runner, vllm_runner, hf_image_prompts, hf_images,
...
@@ -94,7 +94,7 @@ def test_models(hf_runner, vllm_runner, hf_image_prompts, hf_images,
"""
"""
model_id
,
vision_language_config
=
model_and_config
model_id
,
vision_language_config
=
model_and_config
hf_model
=
hf_runner
(
model_id
,
dtype
=
dtype
)
hf_model
=
hf_runner
(
model_id
,
dtype
=
dtype
,
is_vision_model
=
True
)
hf_outputs
=
hf_model
.
generate_greedy
(
hf_image_prompts
,
hf_outputs
=
hf_model
.
generate_greedy
(
hf_image_prompts
,
max_tokens
,
max_tokens
,
images
=
hf_images
)
images
=
hf_images
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment