Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7025b11d
Unverified
Commit
7025b11d
authored
Aug 13, 2024
by
Cyrus Leung
Committed by
GitHub
Aug 13, 2024
Browse files
[Bugfix] Fix weight loading for Chameleon when TP>1 (#7410)
parent
5469146b
Changes
59
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
213 additions
and
123 deletions
+213
-123
tests/conftest.py
tests/conftest.py
+18
-8
tests/distributed/test_multimodal_broadcast.py
tests/distributed/test_multimodal_broadcast.py
+4
-0
tests/entrypoints/openai/test_oot_registration.py
tests/entrypoints/openai/test_oot_registration.py
+6
-2
tests/models/test_chameleon.py
tests/models/test_chameleon.py
+58
-33
tests/models/test_llava.py
tests/models/test_llava.py
+13
-8
tests/models/test_minicpmv.py
tests/models/test_minicpmv.py
+9
-24
tests/models/test_oot_registration.py
tests/models/test_oot_registration.py
+7
-2
vllm/distributed/communication_op.py
vllm/distributed/communication_op.py
+1
-1
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+1
-1
vllm/model_executor/layers/logits_processor.py
vllm/model_executor/layers/logits_processor.py
+8
-4
vllm/model_executor/model_loader/weight_utils.py
vllm/model_executor/model_loader/weight_utils.py
+25
-2
vllm/model_executor/models/arctic.py
vllm/model_executor/models/arctic.py
+5
-2
vllm/model_executor/models/baichuan.py
vllm/model_executor/models/baichuan.py
+5
-2
vllm/model_executor/models/bart.py
vllm/model_executor/models/bart.py
+5
-2
vllm/model_executor/models/blip2.py
vllm/model_executor/models/blip2.py
+5
-2
vllm/model_executor/models/bloom.py
vllm/model_executor/models/bloom.py
+5
-2
vllm/model_executor/models/chameleon.py
vllm/model_executor/models/chameleon.py
+18
-5
vllm/model_executor/models/chatglm.py
vllm/model_executor/models/chatglm.py
+5
-2
vllm/model_executor/models/commandr.py
vllm/model_executor/models/commandr.py
+10
-19
vllm/model_executor/models/dbrx.py
vllm/model_executor/models/dbrx.py
+5
-2
No files found.
tests/conftest.py
View file @
7025b11d
...
@@ -4,7 +4,8 @@ import os
...
@@ -4,7 +4,8 @@ import os
import
sys
import
sys
from
collections
import
UserList
from
collections
import
UserList
from
enum
import
Enum
from
enum
import
Enum
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
TypedDict
,
TypeVar
,
Union
from
typing
import
(
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
TypedDict
,
TypeVar
,
Union
)
import
pytest
import
pytest
import
torch
import
torch
...
@@ -27,7 +28,7 @@ from vllm.logger import init_logger
...
@@ -27,7 +28,7 @@ from vllm.logger import init_logger
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.sequence
import
SampleLogprobs
from
vllm.sequence
import
SampleLogprobs
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
cuda_device_count_stateless
,
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
cuda_device_count_stateless
,
is_cpu
)
identity
,
is_cpu
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -197,6 +198,8 @@ class HfRunner:
...
@@ -197,6 +198,8 @@ class HfRunner:
is_embedding_model
:
bool
=
False
,
is_embedding_model
:
bool
=
False
,
is_vision_model
:
bool
=
False
,
is_vision_model
:
bool
=
False
,
is_encoder_decoder_model
:
bool
=
False
,
is_encoder_decoder_model
:
bool
=
False
,
postprocess_inputs
:
Callable
[[
BatchEncoding
],
BatchEncoding
]
=
identity
,
)
->
None
:
)
->
None
:
torch_dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
dtype
]
torch_dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
dtype
]
...
@@ -242,12 +245,14 @@ class HfRunner:
...
@@ -242,12 +245,14 @@ class HfRunner:
torch_dtype
=
torch_dtype
,
torch_dtype
=
torch_dtype
,
trust_remote_code
=
True
,
trust_remote_code
=
True
,
)
)
except
Exception
:
except
Exception
as
exc
:
logger
.
warning
(
logger
.
warning
(
"Unable to auto-load
processor from HuggingFace for
"
"Unable to auto-load
HuggingFace processor for model (%s).
"
"
model %s.
Using tokenizer instead."
,
model_name
)
"Using tokenizer instead.
Reason: %s
"
,
model_name
,
exc
)
self
.
processor
=
self
.
tokenizer
self
.
processor
=
self
.
tokenizer
self
.
postprocess_inputs
=
postprocess_inputs
def
generate
(
def
generate
(
self
,
self
,
prompts
:
List
[
str
],
prompts
:
List
[
str
],
...
@@ -267,6 +272,7 @@ class HfRunner:
...
@@ -267,6 +272,7 @@ class HfRunner:
processor_kwargs
[
"images"
]
=
images
[
i
]
processor_kwargs
[
"images"
]
=
images
[
i
]
inputs
=
self
.
processor
(
**
processor_kwargs
)
inputs
=
self
.
processor
(
**
processor_kwargs
)
inputs
=
self
.
postprocess_inputs
(
inputs
)
output_ids
=
self
.
model
.
generate
(
output_ids
=
self
.
model
.
generate
(
**
self
.
wrap_device
(
inputs
),
**
self
.
wrap_device
(
inputs
),
...
@@ -336,6 +342,7 @@ class HfRunner:
...
@@ -336,6 +342,7 @@ class HfRunner:
processor_kwargs
[
"images"
]
=
images
[
i
]
processor_kwargs
[
"images"
]
=
images
[
i
]
inputs
=
self
.
processor
(
**
processor_kwargs
)
inputs
=
self
.
processor
(
**
processor_kwargs
)
inputs
=
self
.
postprocess_inputs
(
inputs
)
output
=
self
.
model
.
generate
(
output
=
self
.
model
.
generate
(
**
self
.
wrap_device
(
inputs
),
**
self
.
wrap_device
(
inputs
),
...
@@ -420,6 +427,7 @@ class HfRunner:
...
@@ -420,6 +427,7 @@ class HfRunner:
processor_kwargs
[
"images"
]
=
images
[
i
]
processor_kwargs
[
"images"
]
=
images
[
i
]
inputs
=
self
.
processor
(
**
processor_kwargs
)
inputs
=
self
.
processor
(
**
processor_kwargs
)
inputs
=
self
.
postprocess_inputs
(
inputs
)
output
=
self
.
model
.
generate
(
output
=
self
.
model
.
generate
(
**
self
.
wrap_device
(
inputs
),
**
self
.
wrap_device
(
inputs
),
...
@@ -552,7 +560,8 @@ class VllmRunner:
...
@@ -552,7 +560,8 @@ class VllmRunner:
self
,
self
,
prompts
:
List
[
str
],
prompts
:
List
[
str
],
sampling_params
:
SamplingParams
,
sampling_params
:
SamplingParams
,
images
:
Optional
[
List
[
Image
.
Image
]]
=
None
,
images
:
Optional
[
Union
[
List
[
Image
.
Image
],
List
[
List
[
Image
.
Image
]]]]
=
None
,
)
->
List
[
Tuple
[
List
[
List
[
int
]],
List
[
str
]]]:
)
->
List
[
Tuple
[
List
[
List
[
int
]],
List
[
str
]]]:
if
images
is
not
None
:
if
images
is
not
None
:
assert
len
(
prompts
)
==
len
(
images
)
assert
len
(
prompts
)
==
len
(
images
)
...
@@ -587,7 +596,7 @@ class VllmRunner:
...
@@ -587,7 +596,7 @@ class VllmRunner:
for
req_output
in
req_outputs
:
for
req_output
in
req_outputs
:
for
sample
in
req_output
.
outputs
:
for
sample
in
req_output
.
outputs
:
output_str
=
sample
.
text
output_str
=
sample
.
text
output_ids
=
sample
.
token_ids
output_ids
=
list
(
sample
.
token_ids
)
output_logprobs
=
sample
.
logprobs
output_logprobs
=
sample
.
logprobs
outputs
.
append
((
output_ids
,
output_str
,
output_logprobs
))
outputs
.
append
((
output_ids
,
output_str
,
output_logprobs
))
return
outputs
return
outputs
...
@@ -596,7 +605,8 @@ class VllmRunner:
...
@@ -596,7 +605,8 @@ class VllmRunner:
self
,
self
,
prompts
:
List
[
str
],
prompts
:
List
[
str
],
sampling_params
:
SamplingParams
,
sampling_params
:
SamplingParams
,
images
:
Optional
[
List
[
Image
.
Image
]]
=
None
,
images
:
Optional
[
Union
[
List
[
Image
.
Image
],
List
[
List
[
Image
.
Image
]]]]
=
None
,
)
->
List
[
Tuple
[
List
[
int
],
str
,
Optional
[
SampleLogprobs
]]]:
)
->
List
[
Tuple
[
List
[
int
],
str
,
Optional
[
SampleLogprobs
]]]:
assert
sampling_params
.
logprobs
is
not
None
assert
sampling_params
.
logprobs
is
not
None
...
...
tests/distributed/test_multimodal_broadcast.py
View file @
7025b11d
...
@@ -18,8 +18,10 @@ from ..utils import fork_new_process_for_each_test
...
@@ -18,8 +18,10 @@ from ..utils import fork_new_process_for_each_test
@
pytest
.
mark
.
parametrize
(
"model, distributed_executor_backend"
,
[
@
pytest
.
mark
.
parametrize
(
"model, distributed_executor_backend"
,
[
(
"llava-hf/llava-1.5-7b-hf"
,
"ray"
),
(
"llava-hf/llava-1.5-7b-hf"
,
"ray"
),
(
"llava-hf/llava-v1.6-mistral-7b-hf"
,
"ray"
),
(
"llava-hf/llava-v1.6-mistral-7b-hf"
,
"ray"
),
(
"facebook/chameleon-7b"
,
"ray"
),
(
"llava-hf/llava-1.5-7b-hf"
,
"mp"
),
(
"llava-hf/llava-1.5-7b-hf"
,
"mp"
),
(
"llava-hf/llava-v1.6-mistral-7b-hf"
,
"mp"
),
(
"llava-hf/llava-v1.6-mistral-7b-hf"
,
"mp"
),
(
"facebook/chameleon-7b"
,
"mp"
),
])
])
@
fork_new_process_for_each_test
@
fork_new_process_for_each_test
def
test_models
(
hf_runner
,
vllm_runner
,
image_assets
,
model
:
str
,
def
test_models
(
hf_runner
,
vllm_runner
,
image_assets
,
model
:
str
,
...
@@ -34,6 +36,8 @@ def test_models(hf_runner, vllm_runner, image_assets, model: str,
...
@@ -34,6 +36,8 @@ def test_models(hf_runner, vllm_runner, image_assets, model: str,
from
..models.test_llava
import
models
,
run_test
from
..models.test_llava
import
models
,
run_test
elif
model
.
startswith
(
"llava-hf/llava-v1.6"
):
elif
model
.
startswith
(
"llava-hf/llava-v1.6"
):
from
..models.test_llava_next
import
models
,
run_test
from
..models.test_llava_next
import
models
,
run_test
elif
model
.
startswith
(
"facebook/chameleon"
):
from
..models.test_chameleon
import
models
,
run_test
else
:
else
:
raise
NotImplementedError
(
f
"Unsupported model:
{
model
}
"
)
raise
NotImplementedError
(
f
"Unsupported model:
{
model
}
"
)
...
...
tests/entrypoints/openai/test_oot_registration.py
View file @
7025b11d
import
sys
import
sys
import
time
import
time
from
typing
import
Optional
import
torch
import
torch
from
openai
import
OpenAI
,
OpenAIError
from
openai
import
OpenAI
,
OpenAIError
...
@@ -17,8 +18,11 @@ assert chatml_jinja_path.exists()
...
@@ -17,8 +18,11 @@ assert chatml_jinja_path.exists()
class
MyOPTForCausalLM
(
OPTForCausalLM
):
class
MyOPTForCausalLM
(
OPTForCausalLM
):
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
def
compute_logits
(
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
# this dummy model always predicts the first token
# this dummy model always predicts the first token
logits
=
super
().
compute_logits
(
hidden_states
,
sampling_metadata
)
logits
=
super
().
compute_logits
(
hidden_states
,
sampling_metadata
)
logits
.
zero_
()
logits
.
zero_
()
...
...
tests/models/test_chameleon.py
View file @
7025b11d
import
re
from
typing
import
List
,
Optional
,
Type
from
typing
import
List
,
Optional
,
Type
import
pytest
import
pytest
from
transformers
import
BatchEncoding
from
vllm.multimodal.utils
import
rescale_image_size
from
vllm.multimodal.utils
import
rescale_image_size
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
..conftest
import
IMAGE_ASSETS
,
VllmRunner
,
_ImageAssets
from
..conftest
import
IMAGE_ASSETS
,
HfRunner
,
VllmRunner
,
_ImageAssets
from
.utils
import
check_outputs_equal
pytestmark
=
pytest
.
mark
.
vlm
pytestmark
=
pytest
.
mark
.
vlm
...
@@ -19,9 +21,8 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
...
@@ -19,9 +21,8 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
models
=
[
"facebook/chameleon-7b"
]
models
=
[
"facebook/chameleon-7b"
]
#TODO (ywang96): Add correctness test when chameleon is
# available on transformers.
def
run_test
(
def
run_test
(
hf_runner
:
Type
[
HfRunner
],
vllm_runner
:
Type
[
VllmRunner
],
vllm_runner
:
Type
[
VllmRunner
],
image_assets
:
_ImageAssets
,
image_assets
:
_ImageAssets
,
model
:
str
,
model
:
str
,
...
@@ -29,13 +30,20 @@ def run_test(
...
@@ -29,13 +30,20 @@ def run_test(
size_factors
:
List
[
float
],
size_factors
:
List
[
float
],
dtype
:
str
,
dtype
:
str
,
max_tokens
:
int
,
max_tokens
:
int
,
num_logprobs
:
int
,
tensor_parallel_size
:
int
,
tensor_parallel_size
:
int
,
distributed_executor_backend
:
Optional
[
str
]
=
None
,
distributed_executor_backend
:
Optional
[
str
]
=
None
,
):
):
"""Test if the model can generate text given
"""Inference result should be the same between hf and vllm.
a batch of images and prompts.
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
"""
torch_dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
dtype
]
images
=
[
asset
.
pil_image
for
asset
in
image_assets
]
images
=
[
asset
.
pil_image
for
asset
in
image_assets
]
inputs_per_image
=
[(
inputs_per_image
=
[(
...
@@ -50,35 +58,49 @@ def run_test(
...
@@ -50,35 +58,49 @@ def run_test(
distributed_executor_backend
=
distributed_executor_backend
,
distributed_executor_backend
=
distributed_executor_backend
,
enforce_eager
=
True
)
as
vllm_model
:
enforce_eager
=
True
)
as
vllm_model
:
for
prompts
,
images
in
inputs_per_image
:
vllm_outputs_per_image
=
[
vllm_outputs
=
vllm_model
.
generate_greedy
(
prompts
,
vllm_model
.
generate_greedy_logprobs
(
prompts
,
max_tokens
,
max_tokens
,
images
=
images
)
num_logprobs
=
num_logprobs
,
for
i
in
range
(
len
(
vllm_outputs
)):
images
=
images
)
for
prompts
,
images
in
inputs_per_image
# format prompt back to original
]
replacements
=
{
"<racm3:break>"
:
""
,
def
process
(
hf_inputs
:
BatchEncoding
):
"<eoss>"
:
""
,
hf_inputs
[
"pixel_values"
]
=
hf_inputs
[
"pixel_values"
]
\
"<reserved08706>"
:
""
.
to
(
torch_dtype
)
# type: ignore
}
return
hf_inputs
pattern
=
'|'
.
join
(
replacements
.
keys
())
vllm_result
=
re
.
sub
(
with
hf_runner
(
model
,
pattern
,
dtype
=
dtype
,
lambda
match
:
replacements
[
match
.
group
(
0
)],
#noqa B023
postprocess_inputs
=
process
,
vllm_outputs
[
i
][
1
])
is_vision_model
=
True
)
as
hf_model
:
vllm_result
=
vllm_result
.
replace
(
"<image>"
,
""
,
1023
)
hf_outputs_per_image
=
[
assert
vllm_result
[:
len
(
prompts
[
i
])]
==
prompts
[
i
]
hf_model
.
generate_greedy_logprobs_limit
(
prompts
,
max_tokens
,
# assert at least 10 new characters are generated
num_logprobs
=
num_logprobs
,
# (to take stop token into account)
images
=
images
)
assert
len
(
vllm_outputs
[
i
][
1
])
-
len
(
prompts
[
i
])
>
10
for
prompts
,
images
in
inputs_per_image
]
for
hf_outputs
,
vllm_outputs
in
zip
(
hf_outputs_per_image
,
vllm_outputs_per_image
):
# HF Logprobs include image tokens, unlike vLLM, so we don't directly
# compare them
check_outputs_equal
(
outputs_0_lst
=
[
outputs
[:
2
]
for
outputs
in
hf_outputs
],
outputs_1_lst
=
[
outputs
[:
2
]
for
outputs
in
vllm_outputs
],
name_0
=
"hf"
,
name_1
=
"vllm"
,
)
@
pytest
.
mark
.
parametrize
(
"model"
,
models
)
@
pytest
.
mark
.
parametrize
(
"model"
,
models
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"size_factors"
,
"size_factors"
,
[
[
# No image
[],
# Single-scale
# Single-scale
[
1.0
],
[
1.0
],
# Single-scale, batched
# Single-scale, batched
...
@@ -88,15 +110,18 @@ def run_test(
...
@@ -88,15 +110,18 @@ def run_test(
],
],
)
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
8
])
def
test_models
(
vllm_runner
,
image_assets
,
model
,
size_factors
,
dtype
:
str
,
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
max_tokens
:
int
)
->
None
:
def
test_models
(
hf_runner
,
vllm_runner
,
image_assets
,
model
,
size_factors
,
dtype
,
max_tokens
,
num_logprobs
)
->
None
:
run_test
(
run_test
(
hf_runner
,
vllm_runner
,
vllm_runner
,
image_assets
,
image_assets
,
model
,
model
,
size_factors
=
size_factors
,
size_factors
=
size_factors
,
dtype
=
dtype
,
dtype
=
dtype
,
max_tokens
=
max_tokens
,
max_tokens
=
max_tokens
,
num_logprobs
=
num_logprobs
,
tensor_parallel_size
=
1
,
tensor_parallel_size
=
1
,
)
)
tests/models/test_llava.py
View file @
7025b11d
from
typing
import
List
,
Optional
,
Tuple
,
Type
from
typing
import
List
,
Optional
,
Tuple
,
Type
import
pytest
import
pytest
from
transformers
import
AutoConfig
,
AutoTokenizer
from
transformers
import
AutoConfig
,
AutoTokenizer
,
BatchEncoding
from
vllm.multimodal.utils
import
rescale_image_size
from
vllm.multimodal.utils
import
rescale_image_size
from
vllm.sequence
import
SampleLogprobs
from
vllm.sequence
import
SampleLogprobs
...
@@ -110,16 +110,21 @@ def run_test(
...
@@ -110,16 +110,21 @@ def run_test(
for
prompts
,
images
in
inputs_per_image
for
prompts
,
images
in
inputs_per_image
]
]
with
hf_runner
(
model
,
dtype
=
dtype
,
is_vision_model
=
True
)
as
hf_model
:
if
mantis_processor
is
not
None
:
if
mantis_processor
is
not
None
:
def
process
(
*
args
,
**
kwargs
):
def
process
(
hf_inputs
:
BatchEncoding
):
output
=
mantis_processor
(
*
args
,
**
kwargs
)
hf_inputs
[
"pixel_values"
]
=
hf_inputs
[
"pixel_values"
]
\
output
[
"pixel_values"
]
=
output
[
"pixel_values"
].
to
(
torch_dtype
)
.
to
(
torch_dtype
)
# type: ignore
return
output
return
hf_inputs
else
:
hf_model
.
processor
=
process
def
process
(
hf_inputs
:
BatchEncoding
):
return
hf_inputs
with
hf_runner
(
model
,
dtype
=
dtype
,
postprocess_inputs
=
process
,
is_vision_model
=
True
)
as
hf_model
:
hf_outputs_per_image
=
[
hf_outputs_per_image
=
[
hf_model
.
generate_greedy_logprobs_limit
(
prompts
,
hf_model
.
generate_greedy_logprobs_limit
(
prompts
,
max_tokens
,
max_tokens
,
...
...
tests/models/test_minicpmv.py
View file @
7025b11d
from
collections
import
UserDict
from
typing
import
List
,
Optional
,
Tuple
,
Type
from
typing
import
List
,
Optional
,
Tuple
,
Type
import
pytest
import
pytest
import
torch
import
torch
import
torch.types
import
torch.types
from
transformers
import
Batch
Feature
from
transformers
import
Batch
Encoding
from
vllm.multimodal.utils
import
rescale_image_size
from
vllm.multimodal.utils
import
rescale_image_size
from
vllm.sequence
import
SampleLogprobs
from
vllm.sequence
import
SampleLogprobs
...
@@ -14,18 +13,6 @@ from .utils import check_logprobs_close
...
@@ -14,18 +13,6 @@ from .utils import check_logprobs_close
pytestmark
=
pytest
.
mark
.
vlm
pytestmark
=
pytest
.
mark
.
vlm
class
NestedInputs
(
UserDict
):
def
__init__
(
self
,
model_inputs
:
BatchFeature
):
super
().
__init__
({
"model_inputs"
:
model_inputs
})
self
.
model_inputs
=
model_inputs
def
to
(
self
,
device
:
torch
.
types
.
Device
):
return
NestedInputs
(
self
.
model_inputs
.
to
(
device
))
# The image token is placed before "user" on purpose so that the test can pass
# The image token is placed before "user" on purpose so that the test can pass
HF_IMAGE_PROMPTS
=
IMAGE_ASSETS
.
prompts
({
HF_IMAGE_PROMPTS
=
IMAGE_ASSETS
.
prompts
({
"stop_sign"
:
"stop_sign"
:
...
@@ -41,6 +28,10 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
...
@@ -41,6 +28,10 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
models
=
[
"openbmb/MiniCPM-Llama3-V-2_5"
]
models
=
[
"openbmb/MiniCPM-Llama3-V-2_5"
]
def
_wrap_inputs
(
hf_inputs
:
BatchEncoding
)
->
BatchEncoding
:
return
BatchEncoding
({
"model_inputs"
:
hf_inputs
})
def
trunc_hf_output
(
hf_output
:
Tuple
[
List
[
int
],
str
,
def
trunc_hf_output
(
hf_output
:
Tuple
[
List
[
int
],
str
,
Optional
[
SampleLogprobs
]]):
Optional
[
SampleLogprobs
]]):
output_ids
,
output_str
,
out_logprobs
=
hf_output
output_ids
,
output_str
,
out_logprobs
=
hf_output
...
@@ -105,11 +96,8 @@ def run_test(
...
@@ -105,11 +96,8 @@ def run_test(
for
prompts
,
images
in
inputs_per_image
for
prompts
,
images
in
inputs_per_image
]
]
with
hf_runner
(
model
,
dtype
=
dtype
)
as
hf_model
,
torch
.
no_grad
():
hf_model
=
hf_runner
(
model
,
dtype
=
dtype
,
postprocess_inputs
=
_wrap_inputs
)
hf_processor
=
hf_model
.
processor
with
hf_model
,
torch
.
no_grad
():
hf_model
.
processor
=
lambda
**
kw
:
NestedInputs
(
hf_processor
(
**
kw
)
# type: ignore
)
hf_outputs_per_image
=
[
hf_outputs_per_image
=
[
hf_model
.
generate_greedy_logprobs_limit
(
prompts
,
hf_model
.
generate_greedy_logprobs_limit
(
prompts
,
max_tokens
,
max_tokens
,
...
@@ -224,11 +212,8 @@ def run_multi_image_test(
...
@@ -224,11 +212,8 @@ def run_multi_image_test(
for
prompts
,
images
in
inputs_per_case
for
prompts
,
images
in
inputs_per_case
]
]
with
hf_runner
(
model
,
dtype
=
dtype
)
as
hf_model
,
torch
.
no_grad
():
hf_model
=
hf_runner
(
model
,
dtype
=
dtype
,
postprocess_inputs
=
_wrap_inputs
)
hf_processor
=
hf_model
.
processor
with
hf_model
,
torch
.
no_grad
():
hf_model
.
processor
=
lambda
**
kw
:
NestedInputs
(
hf_processor
(
**
kw
)
# type: ignore
)
hf_outputs_per_case
=
[
hf_outputs_per_case
=
[
hf_model
.
generate_greedy_logprobs_limit
(
prompts
,
hf_model
.
generate_greedy_logprobs_limit
(
prompts
,
max_tokens
,
max_tokens
,
...
...
tests/models/test_oot_registration.py
View file @
7025b11d
from
typing
import
Optional
import
torch
import
torch
from
vllm
import
LLM
,
ModelRegistry
,
SamplingParams
from
vllm
import
LLM
,
ModelRegistry
,
SamplingParams
...
@@ -7,8 +9,11 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
...
@@ -7,8 +9,11 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
class
MyOPTForCausalLM
(
OPTForCausalLM
):
class
MyOPTForCausalLM
(
OPTForCausalLM
):
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
def
compute_logits
(
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
# this dummy model always predicts the first token
# this dummy model always predicts the first token
logits
=
super
().
compute_logits
(
hidden_states
,
sampling_metadata
)
logits
=
super
().
compute_logits
(
hidden_states
,
sampling_metadata
)
logits
.
zero_
()
logits
.
zero_
()
...
...
vllm/distributed/communication_op.py
View file @
7025b11d
...
@@ -19,7 +19,7 @@ def tensor_model_parallel_all_gather(input_: torch.Tensor,
...
@@ -19,7 +19,7 @@ def tensor_model_parallel_all_gather(input_: torch.Tensor,
def
tensor_model_parallel_gather
(
input_
:
torch
.
Tensor
,
def
tensor_model_parallel_gather
(
input_
:
torch
.
Tensor
,
dst
:
int
=
0
,
dst
:
int
=
0
,
dim
:
int
=
-
1
)
->
torch
.
Tensor
:
dim
:
int
=
-
1
)
->
Optional
[
torch
.
Tensor
]
:
"""Gather the input tensor across model parallel group."""
"""Gather the input tensor across model parallel group."""
return
get_tp_group
().
gather
(
input_
,
dst
,
dim
)
return
get_tp_group
().
gather
(
input_
,
dst
,
dim
)
...
...
vllm/distributed/parallel_state.py
View file @
7025b11d
...
@@ -329,7 +329,7 @@ class GroupCoordinator:
...
@@ -329,7 +329,7 @@ class GroupCoordinator:
def
gather
(
self
,
def
gather
(
self
,
input_
:
torch
.
Tensor
,
input_
:
torch
.
Tensor
,
dst
:
int
=
0
,
dst
:
int
=
0
,
dim
:
int
=
-
1
)
->
torch
.
Tensor
:
dim
:
int
=
-
1
)
->
Optional
[
torch
.
Tensor
]
:
"""
"""
NOTE: We assume that the input tensor is on the same device across
NOTE: We assume that the input tensor is on the same device across
all the ranks.
all the ranks.
...
...
vllm/model_executor/layers/logits_processor.py
View file @
7025b11d
...
@@ -50,7 +50,7 @@ class LogitsProcessor(nn.Module):
...
@@ -50,7 +50,7 @@ class LogitsProcessor(nn.Module):
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
embedding_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
embedding_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
Optional
[
torch
.
Tensor
]
:
if
self
.
logits_as_input
:
if
self
.
logits_as_input
:
logits
=
hidden_states
logits
=
hidden_states
else
:
else
:
...
@@ -73,14 +73,18 @@ class LogitsProcessor(nn.Module):
...
@@ -73,14 +73,18 @@ class LogitsProcessor(nn.Module):
return
logits
return
logits
def
_get_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
def
_get_logits
(
lm_head
:
VocabParallelEmbedding
,
self
,
embedding_bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
hidden_states
:
torch
.
Tensor
,
lm_head
:
VocabParallelEmbedding
,
embedding_bias
:
Optional
[
torch
.
Tensor
],
)
->
Optional
[
torch
.
Tensor
]:
# Get the logits for the next tokens.
# Get the logits for the next tokens.
logits
=
lm_head
.
linear_method
.
apply
(
lm_head
,
logits
=
lm_head
.
linear_method
.
apply
(
lm_head
,
hidden_states
,
hidden_states
,
bias
=
embedding_bias
)
bias
=
embedding_bias
)
if
self
.
use_gather
:
if
self
.
use_gather
:
# None may be returned for rank > 0
logits
=
tensor_model_parallel_gather
(
logits
)
logits
=
tensor_model_parallel_gather
(
logits
)
else
:
else
:
# Gather is not supported for some devices such as TPUs.
# Gather is not supported for some devices such as TPUs.
...
...
vllm/model_executor/model_loader/weight_utils.py
View file @
7025b11d
...
@@ -19,6 +19,7 @@ from tqdm.auto import tqdm
...
@@ -19,6 +19,7 @@ from tqdm.auto import tqdm
from
transformers.utils
import
SAFE_WEIGHTS_INDEX_NAME
from
transformers.utils
import
SAFE_WEIGHTS_INDEX_NAME
from
vllm.config
import
LoadConfig
,
ModelConfig
from
vllm.config
import
LoadConfig
,
ModelConfig
from
vllm.distributed
import
get_tensor_model_parallel_rank
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization
import
(
QuantizationConfig
,
from
vllm.model_executor.layers.quantization
import
(
QuantizationConfig
,
get_quantization_config
)
get_quantization_config
)
...
@@ -514,8 +515,30 @@ def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
...
@@ -514,8 +515,30 @@ def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
def
default_weight_loader
(
param
:
torch
.
Tensor
,
def
default_weight_loader
(
param
:
torch
.
Tensor
,
loaded_weight
:
torch
.
Tensor
)
->
None
:
loaded_weight
:
torch
.
Tensor
)
->
None
:
"""Default weight loader."""
"""Default weight loader."""
assert
param
.
size
()
==
loaded_weight
.
size
()
try
:
param
.
data
.
copy_
(
loaded_weight
)
assert
param
.
size
()
==
loaded_weight
.
size
(),
(
f
"Attempted to load weight (
{
loaded_weight
.
size
()
}
) "
f
"into parameter (
{
param
.
size
()
}
)"
)
param
.
data
.
copy_
(
loaded_weight
)
except
Exception
:
# NOTE: This exception is added for the purpose of setting breakpoint to
# debug weight loading issues.
raise
def
row_parallel_weight_loader
(
param
:
torch
.
Tensor
,
loaded_weight
:
torch
.
Tensor
)
->
None
:
"""Load weights that are row-parallelized."""
tp_rank
=
get_tensor_model_parallel_rank
()
shard_dim
=
0
if
param
.
dim
()
!=
1
else
None
if
shard_dim
is
not
None
:
shard_size
=
param
.
data
.
shape
[
shard_dim
]
start_idx
=
tp_rank
*
shard_size
loaded_weight
=
loaded_weight
.
narrow
(
shard_dim
,
start_idx
,
shard_size
)
return
default_weight_loader
(
param
,
loaded_weight
)
def
initialize_dummy_weights
(
def
initialize_dummy_weights
(
...
...
vllm/model_executor/models/arctic.py
View file @
7025b11d
...
@@ -433,8 +433,11 @@ class ArcticForCausalLM(nn.Module):
...
@@ -433,8 +433,11 @@ class ArcticForCausalLM(nn.Module):
attn_metadata
)
attn_metadata
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
def
compute_logits
(
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
sampling_metadata
)
return
logits
return
logits
...
...
vllm/model_executor/models/baichuan.py
View file @
7025b11d
...
@@ -346,8 +346,11 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
...
@@ -346,8 +346,11 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
attn_metadata
)
attn_metadata
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
def
compute_logits
(
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
sampling_metadata
)
return
logits
return
logits
...
...
vllm/model_executor/models/bart.py
View file @
7025b11d
...
@@ -872,8 +872,11 @@ class BartForConditionalGeneration(nn.Module):
...
@@ -872,8 +872,11 @@ class BartForConditionalGeneration(nn.Module):
return
self
.
model
(
input_ids
,
positions
,
encoder_input_ids
,
return
self
.
model
(
input_ids
,
positions
,
encoder_input_ids
,
encoder_positions
,
kv_caches
,
attn_metadata
)
encoder_positions
,
kv_caches
,
attn_metadata
)
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
def
compute_logits
(
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
sampling_metadata
)
return
logits
return
logits
...
...
vllm/model_executor/models/blip2.py
View file @
7025b11d
...
@@ -637,8 +637,11 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsVision):
...
@@ -637,8 +637,11 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsVision):
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
def
compute_logits
(
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
get_lm_head
(),
hidden_states
,
logits
=
self
.
logits_processor
(
self
.
get_lm_head
(),
hidden_states
,
sampling_metadata
)
sampling_metadata
)
return
logits
return
logits
...
...
vllm/model_executor/models/bloom.py
View file @
7025b11d
...
@@ -292,8 +292,11 @@ class BloomForCausalLM(nn.Module):
...
@@ -292,8 +292,11 @@ class BloomForCausalLM(nn.Module):
attn_metadata
)
attn_metadata
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
def
compute_logits
(
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
sampling_metadata
)
return
logits
return
logits
...
...
vllm/model_executor/models/chameleon.py
View file @
7025b11d
...
@@ -25,8 +25,10 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
...
@@ -25,8 +25,10 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
row_parallel_weight_loader
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.image
import
(
cached_get_tokenizer
,
from
vllm.multimodal.image
import
(
cached_get_tokenizer
,
repeat_and_pad_image_tokens
)
repeat_and_pad_image_tokens
)
...
@@ -141,6 +143,11 @@ class ChameleonLayerNorm(nn.LayerNorm):
...
@@ -141,6 +143,11 @@ class ChameleonLayerNorm(nn.LayerNorm):
super
().
__init__
(
hidden_size
,
*
args
,
**
kwargs
)
super
().
__init__
(
hidden_size
,
*
args
,
**
kwargs
)
self
.
normalized_shape
=
(
hidden_size
[
-
1
],
)
self
.
normalized_shape
=
(
hidden_size
[
-
1
],
)
set_weight_attrs
(
self
.
weight
,
{
"weight_loader"
:
row_parallel_weight_loader
})
set_weight_attrs
(
self
.
bias
,
{
"weight_loader"
:
row_parallel_weight_loader
})
def
forward
(
self
,
hidden_states
):
def
forward
(
self
,
hidden_states
):
hidden_states
=
F
.
layer_norm
(
hidden_states
,
hidden_states
=
F
.
layer_norm
(
hidden_states
,
self
.
normalized_shape
,
self
.
normalized_shape
,
...
@@ -697,6 +704,8 @@ class ChameleonVQVAEEncoder(nn.Module):
...
@@ -697,6 +704,8 @@ class ChameleonVQVAEEncoder(nn.Module):
)
)
def
forward
(
self
,
pixel_values
:
torch
.
Tensor
):
def
forward
(
self
,
pixel_values
:
torch
.
Tensor
):
pixel_values
=
pixel_values
.
to
(
self
.
conv_in
.
weight
.
dtype
)
# downsampling
# downsampling
hidden_states
=
[
self
.
conv_in
(
pixel_values
)]
hidden_states
=
[
self
.
conv_in
(
pixel_values
)]
for
i_level
in
range
(
self
.
num_resolutions
):
for
i_level
in
range
(
self
.
num_resolutions
):
...
@@ -959,15 +968,19 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsVision):
...
@@ -959,15 +968,19 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsVision):
attn_metadata
)
attn_metadata
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
def
compute_logits
(
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
sampling_metadata
)
# Disallow image tokens which does not include special
# Disallow image tokens which does not include special
# begin-image and end-image tokens
# begin-image and end-image tokens
image_tokens
=
self
.
model
.
vocabulary_mapping
.
image_tokens
if
logits
is
not
None
:
logits
[:,
image_tokens
]
=
torch
.
finfo
(
logits
.
dtype
).
min
image_tokens
=
self
.
model
.
vocabulary_mapping
.
image_tokens
logits
[:,
image_tokens
]
=
torch
.
finfo
(
logits
.
dtype
).
min
return
logits
return
logits
...
...
vllm/model_executor/models/chatglm.py
View file @
7025b11d
...
@@ -372,8 +372,11 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
...
@@ -372,8 +372,11 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
attn_metadata
)
attn_metadata
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
def
compute_logits
(
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
sampling_metadata
)
return
logits
return
logits
...
...
vllm/model_executor/models/commandr.py
View file @
7025b11d
...
@@ -25,13 +25,11 @@ from typing import Iterable, List, Optional, Set, Tuple
...
@@ -25,13 +25,11 @@ from typing import Iterable, List, Optional, Set, Tuple
import
torch
import
torch
import
torch.utils.checkpoint
import
torch.utils.checkpoint
from
torch
import
nn
from
torch
import
nn
from
torch.nn.parameter
import
Parameter
from
transformers
import
CohereConfig
from
transformers
import
CohereConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
get_tensor_model_parallel_world_size
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
QKVParallelLinear
,
...
@@ -43,7 +41,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
...
@@ -43,7 +41,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
row_parallel_weight_loader
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
...
@@ -67,25 +66,14 @@ class LayerNorm(nn.Module):
...
@@ -67,25 +66,14 @@ class LayerNorm(nn.Module):
super
().
__init__
()
super
().
__init__
()
self
.
weight
=
nn
.
Parameter
(
torch
.
ones
(
param_shape
))
self
.
weight
=
nn
.
Parameter
(
torch
.
ones
(
param_shape
))
self
.
variance_epsilon
=
eps
self
.
variance_epsilon
=
eps
set_weight_attrs
(
self
.
weight
,
{
"weight_loader"
:
self
.
weight_loader
})
set_weight_attrs
(
self
.
weight
,
{
"weight_loader"
:
row_parallel_weight_loader
})
def
forward
(
self
,
hidden_states
,
residuals
=
None
):
def
forward
(
self
,
hidden_states
,
residuals
=
None
):
hidden_states
=
layer_norm_func
(
hidden_states
,
self
.
weight
,
hidden_states
=
layer_norm_func
(
hidden_states
,
self
.
weight
,
self
.
variance_epsilon
)
self
.
variance_epsilon
)
return
hidden_states
,
residuals
return
hidden_states
,
residuals
def
weight_loader
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
tp_rank
=
get_tensor_model_parallel_rank
()
shard_dim
=
0
if
param
.
dim
()
!=
1
else
None
param_data
=
param
.
data
if
shard_dim
is
not
None
:
shard_size
=
param_data
.
shape
[
shard_dim
]
start_idx
=
tp_rank
*
shard_size
loaded_weight
=
loaded_weight
.
narrow
(
shard_dim
,
start_idx
,
shard_size
)
assert
param_data
.
shape
==
loaded_weight
.
shape
param_data
.
copy_
(
loaded_weight
)
# Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
# Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
class
CohereMLP
(
nn
.
Module
):
class
CohereMLP
(
nn
.
Module
):
...
@@ -359,8 +347,11 @@ class CohereForCausalLM(nn.Module):
...
@@ -359,8 +347,11 @@ class CohereForCausalLM(nn.Module):
attn_metadata
)
attn_metadata
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
def
compute_logits
(
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
is_not_lora
=
hasattr
(
self
.
model
.
embed_tokens
,
'weight'
)
is_not_lora
=
hasattr
(
self
.
model
.
embed_tokens
,
'weight'
)
if
is_not_lora
:
if
is_not_lora
:
logits
=
self
.
logits_processor
(
self
.
model
.
embed_tokens
,
logits
=
self
.
logits_processor
(
self
.
model
.
embed_tokens
,
...
...
vllm/model_executor/models/dbrx.py
View file @
7025b11d
...
@@ -388,8 +388,11 @@ class DbrxForCausalLM(nn.Module):
...
@@ -388,8 +388,11 @@ class DbrxForCausalLM(nn.Module):
attn_metadata
)
attn_metadata
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
def
compute_logits
(
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
sampling_metadata
)
return
logits
return
logits
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment