Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f6a59309
Unverified
Commit
f6a59309
authored
May 09, 2024
by
SangBin Cho
Committed by
GitHub
May 08, 2024
Browse files
[CI] Make mistral tests pass (#4596)
parent
d7740ea4
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
85 additions
and
19 deletions
+85
-19
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+1
-1
tests/conftest.py
tests/conftest.py
+62
-0
tests/models/test_big_models.py
tests/models/test_big_models.py
+1
-1
tests/models/test_mistral.py
tests/models/test_mistral.py
+18
-15
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+3
-2
No files found.
.buildkite/test-pipeline.yaml
View file @
f6a59309
...
...
@@ -76,7 +76,7 @@ steps:
#mirror_hardwares: [amd]
commands
:
-
bash ../.buildkite/download-images.sh
-
pytest -v -s models --ignore=models/test_llava.py
--ignore=models/test_mistral.py
-
pytest -v -s models --ignore=models/test_llava.py
-
label
:
Llava Test
#mirror_hardwares: [amd]
...
...
tests/conftest.py
View file @
f6a59309
...
...
@@ -272,6 +272,68 @@ class HfRunner:
all_logprobs
.
append
(
seq_logprobs
)
return
all_logprobs
def generate_greedy_logprobs_limit(
    self,
    prompts: List[str],
    max_tokens: int,
    num_logprobs: int,
) -> List[Tuple[List[int], str, List[dict]]]:
    """Greedily decode each prompt and record top-k logprobs per new token.

    Each prompt is tokenized and generated on its own (one sequence per
    ``generate`` call) with sampling disabled. For every generation step
    the hidden states returned by ``generate`` are projected through the
    model's output embeddings, log-softmaxed in float32, and reduced to
    the ``num_logprobs`` most likely tokens.

    Args:
        prompts: Prompt strings to run, one generation call each.
        max_tokens: Passed to ``generate`` as ``max_new_tokens``.
        num_logprobs: Number of top entries kept per generated position.

    Returns:
        One ``(output_ids, output_str, output_logprobs)`` tuple per prompt:
        the generated token ids without the prompt, their decoded text,
        and one ``{token_id: logprob}`` dict (``int`` -> ``float``) per
        generation step.
    """
    # The output projection does not change between steps or prompts, so
    # resolve it once instead of on every generation step.
    output_embeddings = self.model.get_output_embeddings()
    output_weight_t = output_embeddings.weight.t()
    output_bias = getattr(output_embeddings, "bias", None)

    all_logprobs = []
    all_output_ids = []
    all_output_strs = []
    for prompt in prompts:
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
        prompt_len = input_ids.shape[1]
        output = self.model.generate(
            input_ids.cuda(),
            use_cache=True,
            do_sample=False,
            max_new_tokens=max_tokens,
            output_hidden_states=True,
            return_dict_in_generate=True,
        )

        # Convert each step to a plain dict right away so the full
        # per-step logprob tensors are not all kept alive at once.
        seq_logprobs_lst = []
        for step_idx, hidden_states in enumerate(output.hidden_states):
            # NOTE(review): presumably [-1] selects the final layer and
            # [0] the only sequence in the batch - confirm against the
            # transformers generate() output documentation.
            last_hidden_states = hidden_states[-1][0]
            logits = torch.matmul(last_hidden_states, output_weight_t)
            if output_bias is not None:
                logits += output_bias.unsqueeze(0)
            tok_logprobs = torch.nn.functional.log_softmax(
                logits, dim=-1, dtype=torch.float32)

            # Drop prompt logprobs: the first step carries one row per
            # prompt position, and only the last row predicts the first
            # generated token.
            if step_idx == 0:
                tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)

            topk = tok_logprobs.topk(num_logprobs)
            seq_logprobs_lst.append(
                dict(zip(topk.indices[0].tolist(),
                         topk.values[0].tolist())))
        all_logprobs.append(seq_logprobs_lst)

        # The returned sequence starts with the prompt tokens. Slice from
        # the prompt length rather than with a negative offset: a negative
        # offset of zero (nothing generated) would select the whole
        # sequence instead of an empty one.
        output_ids = output.sequences[0][prompt_len:]
        all_output_ids.append(output_ids.tolist())
        all_output_strs.append(self.tokenizer.decode(output_ids))

    return list(zip(all_output_ids, all_output_strs, all_logprobs))
def
__del__
(
self
):
del
self
.
model
cleanup
()
...
...
tests/models/test_big_models.py
View file @
f6a59309
...
...
@@ -8,7 +8,7 @@ import pytest
MODELS
=
[
"meta-llama/Llama-2-7b-hf"
,
# "mistralai/Mistral-7B-v0.1", #
Broken
# "mistralai/Mistral-7B-v0.1", #
Tested by test_mistral.py
# "Deci/DeciLM-7b", # Broken
# "tiiuae/falcon-7b", # Broken
"EleutherAI/gpt-j-6b"
,
...
...
tests/models/test_mistral.py
View file @
f6a59309
...
...
@@ -4,6 +4,8 @@ Run `pytest tests/models/test_mistral.py`.
"""
import
pytest
from
tests.models.utils
import
check_logprobs_close
MODELS
=
[
"mistralai/Mistral-7B-Instruct-v0.1"
,
]
...
...
@@ -11,30 +13,31 @@ MODELS = [
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
@
pytest
.
mark
.
skip
(
"Two problems: 1. Failing correctness tests. 2. RuntimeError: expected "
"scalar type BFloat16 but found Half (only in CI)."
)
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
def
test_models
(
hf_runner
,
vllm_runner
,
example_
long_
prompts
,
example_prompts
,
model
:
str
,
dtype
:
str
,
max_tokens
:
int
,
num_logprobs
:
int
,
)
->
None
:
# TODO(sang): Sliding window should be tested separately.
hf_model
=
hf_runner
(
model
,
dtype
=
dtype
)
hf_outputs
=
hf_model
.
generate_greedy
(
example_long_prompts
,
max_tokens
)
hf_outputs
=
hf_model
.
generate_greedy_logprobs_limit
(
example_prompts
,
max_tokens
,
num_logprobs
)
del
hf_model
vllm_model
=
vllm_runner
(
model
,
dtype
=
dtype
)
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_long_prompts
,
max_tokens
)
vllm_outputs
=
vllm_model
.
generate_greedy_logprobs
(
example_prompts
,
max_tokens
,
num_logprobs
)
del
vllm_model
for
i
in
range
(
len
(
example_long_prompts
)):
hf_output_ids
,
hf_output_str
=
hf_outputs
[
i
]
vllm_output_ids
,
vllm_output_str
=
vllm_outputs
[
i
]
assert
hf_output_str
==
vllm_output_str
,
(
f
"Test
{
i
}
:
\n
HF:
{
hf_output_str
!
r
}
\n
vLLM:
{
vllm_output_str
!
r
}
"
)
assert
hf_output_ids
==
vllm_output_ids
,
(
f
"Test
{
i
}
:
\n
HF:
{
hf_output_ids
}
\n
vLLM:
{
vllm_output_ids
}
"
)
check_logprobs_close
(
outputs_0_lst
=
hf_outputs
,
outputs_1_lst
=
vllm_outputs
,
name_0
=
"hf"
,
name_1
=
"vllm"
,
)
vllm/model_executor/layers/rotary_embedding.py
View file @
f6a59309
...
...
@@ -109,7 +109,7 @@ class RotaryEmbedding(nn.Module):
key_pass
=
key
[...,
self
.
rotary_dim
:]
self
.
cos_sin_cache
:
torch
.
Tensor
=
self
.
cos_sin_cache
.
to
(
positions
.
device
)
positions
.
device
,
dtype
=
query
.
dtype
)
cos_sin
=
self
.
cos_sin_cache
[
torch
.
add
(
positions
,
offsets
)
if
offsets
is
not
None
else
positions
]
cos
,
sin
=
cos_sin
.
chunk
(
2
,
dim
=-
1
)
...
...
@@ -143,7 +143,8 @@ class RotaryEmbedding(nn.Module):
key
:
torch
.
Tensor
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
self
.
cos_sin_cache
=
self
.
cos_sin_cache
.
to
(
positions
.
device
)
self
.
cos_sin_cache
=
self
.
cos_sin_cache
.
to
(
positions
.
device
,
dtype
=
query
.
dtype
)
# ops.rotary_embedding()/batched_rotary_embedding()
# are in-place operations that update the query and key tensors.
if
offsets
is
not
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment