Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
59a0192f
Unverified
Commit
59a0192f
authored
Jan 20, 2025
by
Cyrus Leung
Committed by
GitHub
Jan 20, 2025
Browse files
[Core] Interface for accessing model from `VllmRunner` (#10353)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
83609791
Changes
35
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
403 additions
and
292 deletions
+403
-292
tests/conftest.py
tests/conftest.py
+5
-0
tests/engine/test_custom_executor.py
tests/engine/test_custom_executor.py
+3
-1
tests/model_executor/test_model_load_with_params.py
tests/model_executor/test_model_load_with_params.py
+33
-31
tests/models/decoder_only/language/test_jamba.py
tests/models/decoder_only/language/test_jamba.py
+5
-2
tests/models/decoder_only/language/test_mamba.py
tests/models/decoder_only/language/test_mamba.py
+5
-2
tests/models/decoder_only/language/test_models.py
tests/models/decoder_only/language/test_models.py
+5
-2
tests/models/decoder_only/vision_language/test_qwen2_vl.py
tests/models/decoder_only/vision_language/test_qwen2_vl.py
+26
-23
tests/models/embedding/language/test_cls_models.py
tests/models/embedding/language/test_cls_models.py
+5
-2
tests/models/embedding/language/test_embedding.py
tests/models/embedding/language/test_embedding.py
+5
-2
tests/quantization/test_compressed_tensors.py
tests/quantization/test_compressed_tensors.py
+137
-105
tests/quantization/test_fp8.py
tests/quantization/test_fp8.py
+29
-23
tests/quantization/test_lm_head.py
tests/quantization/test_lm_head.py
+20
-17
tests/quantization/test_quark.py
tests/quantization/test_quark.py
+13
-10
tests/tensorizer_loader/test_tensorizer.py
tests/tensorizer_loader/test_tensorizer.py
+16
-18
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+3
-14
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+36
-16
vllm/executor/executor_base.py
vllm/executor/executor_base.py
+41
-9
vllm/executor/mp_distributed_executor.py
vllm/executor/mp_distributed_executor.py
+1
-1
vllm/model_executor/model_loader/tensorizer.py
vllm/model_executor/model_loader/tensorizer.py
+4
-13
vllm/spec_decode/ngram_worker.py
vllm/spec_decode/ngram_worker.py
+11
-1
No files found.
tests/conftest.py
View file @
59a0192f
...
@@ -244,6 +244,7 @@ def video_assets() -> _VideoAssets:
...
@@ -244,6 +244,7 @@ def video_assets() -> _VideoAssets:
_T
=
TypeVar
(
"_T"
,
nn
.
Module
,
torch
.
Tensor
,
BatchEncoding
,
BatchFeature
,
dict
)
_T
=
TypeVar
(
"_T"
,
nn
.
Module
,
torch
.
Tensor
,
BatchEncoding
,
BatchFeature
,
dict
)
_R
=
TypeVar
(
"_R"
)
class
HfRunner
:
class
HfRunner
:
...
@@ -930,6 +931,10 @@ class VllmRunner:
...
@@ -930,6 +931,10 @@ class VllmRunner:
req_outputs
=
self
.
model
.
score
(
text_1
,
text_2
)
req_outputs
=
self
.
model
.
score
(
text_1
,
text_2
)
return
[
req_output
.
outputs
.
score
for
req_output
in
req_outputs
]
return
[
req_output
.
outputs
.
score
for
req_output
in
req_outputs
]
def
apply_model
(
self
,
func
:
Callable
[[
nn
.
Module
],
_R
])
->
list
[
_R
]:
executor
=
self
.
model
.
llm_engine
.
model_executor
return
executor
.
apply_model
(
func
)
def
__enter__
(
self
):
def
__enter__
(
self
):
return
self
return
self
...
...
tests/engine/test_custom_executor.py
View file @
59a0192f
...
@@ -51,7 +51,9 @@ def test_custom_executor(model, tmp_path):
...
@@ -51,7 +51,9 @@ def test_custom_executor(model, tmp_path):
assert
not
os
.
path
.
exists
(
".marker"
)
assert
not
os
.
path
.
exists
(
".marker"
)
engine_args
=
EngineArgs
(
engine_args
=
EngineArgs
(
model
=
model
,
distributed_executor_backend
=
CustomUniExecutor
)
model
=
model
,
distributed_executor_backend
=
CustomUniExecutor
,
)
engine
=
LLMEngine
.
from_engine_args
(
engine_args
)
engine
=
LLMEngine
.
from_engine_args
(
engine_args
)
sampling_params
=
SamplingParams
(
max_tokens
=
1
)
sampling_params
=
SamplingParams
(
max_tokens
=
1
)
...
...
tests/model_executor/test_model_load_with_params.py
View file @
59a0192f
...
@@ -25,13 +25,12 @@ def test_model_loading_with_params(vllm_runner):
...
@@ -25,13 +25,12 @@ def test_model_loading_with_params(vllm_runner):
with
vllm_runner
(
model_name
=
MODEL_NAME
,
with
vllm_runner
(
model_name
=
MODEL_NAME
,
revision
=
REVISION
,
revision
=
REVISION
,
dtype
=
"float16"
,
dtype
=
"float16"
,
max_model_len
=
MAX_MODEL_LEN
)
as
model
:
max_model_len
=
MAX_MODEL_LEN
)
as
vllm_
model
:
output
=
model
.
encode
(
"Write a short story about a robot that"
output
=
vllm_
model
.
encode
(
"Write a short story about a robot that"
" dreams for the first time.
\n
"
)
" dreams for the first time.
\n
"
)
model_config
=
model
.
model
.
llm_engine
.
model_config
model_config
=
vllm_model
.
model
.
llm_engine
.
model_config
model_tokenizer
=
vllm_model
.
model
.
llm_engine
.
tokenizer
model_tokenizer
=
model
.
model
.
llm_engine
.
tokenizer
# asserts on the bert model config file
# asserts on the bert model config file
assert
model_config
.
encoder_config
[
"max_seq_length"
]
==
512
assert
model_config
.
encoder_config
[
"max_seq_length"
]
==
512
...
@@ -46,11 +45,13 @@ def test_model_loading_with_params(vllm_runner):
...
@@ -46,11 +45,13 @@ def test_model_loading_with_params(vllm_runner):
assert
model_tokenizer
.
tokenizer_config
[
"do_lower_case"
]
assert
model_tokenizer
.
tokenizer_config
[
"do_lower_case"
]
assert
model_tokenizer
.
tokenizer
.
model_max_length
==
512
assert
model_tokenizer
.
tokenizer
.
model_max_length
==
512
model
=
model
.
model
.
llm_engine
.
model_executor
\
def
check_model
(
model
):
.
driver_worker
.
model_runner
.
model
assert
isinstance
(
model
,
BertEmbeddingModel
)
assert
isinstance
(
model
,
BertEmbeddingModel
)
assert
model
.
_pooler
.
pooling_type
==
PoolingType
.
CLS
assert
model
.
_pooler
.
pooling_type
==
PoolingType
.
CLS
assert
model
.
_pooler
.
normalize
assert
model
.
_pooler
.
normalize
vllm_model
.
apply_model
(
check_model
)
# assert output
# assert output
assert
output
assert
output
...
@@ -64,13 +65,12 @@ def test_roberta_model_loading_with_params(vllm_runner):
...
@@ -64,13 +65,12 @@ def test_roberta_model_loading_with_params(vllm_runner):
with
vllm_runner
(
model_name
=
MODEL_NAME_ROBERTA
,
with
vllm_runner
(
model_name
=
MODEL_NAME_ROBERTA
,
revision
=
REVISION_ROBERTA
,
revision
=
REVISION_ROBERTA
,
dtype
=
"float16"
,
dtype
=
"float16"
,
max_model_len
=
MAX_MODEL_LEN
)
as
model
:
max_model_len
=
MAX_MODEL_LEN
)
as
vllm_
model
:
output
=
model
.
encode
(
"Write a short story about a robot that"
output
=
vllm_
model
.
encode
(
"Write a short story about a robot that"
" dreams for the first time.
\n
"
)
" dreams for the first time.
\n
"
)
model_config
=
model
.
model
.
llm_engine
.
model_config
model_config
=
vllm_model
.
model
.
llm_engine
.
model_config
model_tokenizer
=
vllm_model
.
model
.
llm_engine
.
tokenizer
model_tokenizer
=
model
.
model
.
llm_engine
.
tokenizer
# asserts on the bert model config file
# asserts on the bert model config file
assert
model_config
.
encoder_config
[
"max_seq_length"
]
==
512
assert
model_config
.
encoder_config
[
"max_seq_length"
]
==
512
...
@@ -84,11 +84,12 @@ def test_roberta_model_loading_with_params(vllm_runner):
...
@@ -84,11 +84,12 @@ def test_roberta_model_loading_with_params(vllm_runner):
assert
model_tokenizer
.
tokenizer_id
==
"intfloat/multilingual-e5-large"
assert
model_tokenizer
.
tokenizer_id
==
"intfloat/multilingual-e5-large"
assert
not
model_tokenizer
.
tokenizer_config
[
"do_lower_case"
]
assert
not
model_tokenizer
.
tokenizer_config
[
"do_lower_case"
]
model
=
model
.
model
.
llm_engine
.
model_executor
\
def
check_model
(
model
):
.
driver_worker
.
model_runner
.
model
assert
isinstance
(
model
,
RobertaEmbeddingModel
)
assert
isinstance
(
model
,
RobertaEmbeddingModel
)
assert
model
.
_pooler
.
pooling_type
==
PoolingType
.
MEAN
assert
model
.
_pooler
.
pooling_type
==
PoolingType
.
MEAN
assert
model
.
_pooler
.
normalize
assert
model
.
_pooler
.
normalize
vllm_model
.
apply_model
(
check_model
)
# assert output
# assert output
assert
output
assert
output
...
@@ -103,17 +104,18 @@ def test_facebook_roberta_model_loading_with_params(vllm_runner):
...
@@ -103,17 +104,18 @@ def test_facebook_roberta_model_loading_with_params(vllm_runner):
model_name
=
"FacebookAI/roberta-base"
model_name
=
"FacebookAI/roberta-base"
with
vllm_runner
(
model_name
=
model_name
,
with
vllm_runner
(
model_name
=
model_name
,
dtype
=
"float16"
,
dtype
=
"float16"
,
max_model_len
=
MAX_MODEL_LEN
)
as
model
:
max_model_len
=
MAX_MODEL_LEN
)
as
vllm_
model
:
output
=
model
.
encode
(
"Write a short story about a robot that"
output
=
vllm_
model
.
encode
(
"Write a short story about a robot that"
" dreams for the first time.
\n
"
)
" dreams for the first time.
\n
"
)
model_tokenizer
=
model
.
model
.
llm_engine
.
tokenizer
model_tokenizer
=
vllm_
model
.
model
.
llm_engine
.
tokenizer
assert
model_tokenizer
.
tokenizer_id
==
model_name
assert
model_tokenizer
.
tokenizer_id
==
model_name
model
=
model
.
model
.
llm_engine
.
model_executor
\
def
check_model
(
model
):
.
driver_worker
.
model_runner
.
model
assert
isinstance
(
model
,
RobertaEmbeddingModel
)
assert
not
hasattr
(
model
,
"lm_head"
)
assert
not
hasattr
(
model
,
"lm_head"
)
assert
isinstance
(
model
,
RobertaEmbeddingModel
)
assert
isinstance
(
model
.
_pooler
,
CLSPool
)
assert
isinstance
(
model
.
_pooler
,
CLSPool
)
vllm_model
.
apply_model
(
check_model
)
assert
output
assert
output
tests/models/decoder_only/language/test_jamba.py
View file @
59a0192f
...
@@ -33,10 +33,13 @@ def test_models(
...
@@ -33,10 +33,13 @@ def test_models(
with
vllm_runner
(
model
,
dtype
=
dtype
)
as
vllm_model
:
with
vllm_runner
(
model
,
dtype
=
dtype
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
# This test is for verifying whether the model's extra_repr
# This test is for verifying whether the model's extra_repr
# can be printed correctly.
# can be printed correctly.
print
(
vllm_model
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
def
print_model
(
model
):
model_runner
.
model
)
print
(
model
)
vllm_model
.
apply_model
(
print_model
)
for
i
in
range
(
len
(
example_prompts
)):
for
i
in
range
(
len
(
example_prompts
)):
hf_output_ids
,
hf_output_str
=
hf_outputs
[
i
]
hf_output_ids
,
hf_output_str
=
hf_outputs
[
i
]
...
...
tests/models/decoder_only/language/test_mamba.py
View file @
59a0192f
...
@@ -51,10 +51,13 @@ def test_models(
...
@@ -51,10 +51,13 @@ def test_models(
with
vllm_runner
(
model
,
dtype
=
dtype
)
as
vllm_model
:
with
vllm_runner
(
model
,
dtype
=
dtype
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
# This test is for verifying whether the model's extra_repr
# This test is for verifying whether the model's extra_repr
# can be printed correctly.
# can be printed correctly.
print
(
vllm_model
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
def
print_model
(
model
):
model_runner
.
model
)
print
(
model
)
vllm_model
.
apply_model
(
print_model
)
for
i
in
range
(
len
(
example_prompts
)):
for
i
in
range
(
len
(
example_prompts
)):
hf_output_ids
,
hf_output_str
=
hf_outputs
[
i
]
hf_output_ids
,
hf_output_str
=
hf_outputs
[
i
]
...
...
tests/models/decoder_only/language/test_models.py
View file @
59a0192f
...
@@ -73,10 +73,13 @@ def test_models(
...
@@ -73,10 +73,13 @@ def test_models(
with
vllm_runner
(
model
,
dtype
=
dtype
)
as
vllm_model
:
with
vllm_runner
(
model
,
dtype
=
dtype
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
generate_greedy_logprobs
(
vllm_outputs
=
vllm_model
.
generate_greedy_logprobs
(
example_prompts
,
max_tokens
,
num_logprobs
)
example_prompts
,
max_tokens
,
num_logprobs
)
# This test is for verifying whether the model's extra_repr
# This test is for verifying whether the model's extra_repr
# can be printed correctly.
# can be printed correctly.
print
(
vllm_model
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
def
print_model
(
model
):
model_runner
.
model
)
print
(
model
)
vllm_model
.
apply_model
(
print_model
)
check_logprobs_close
(
check_logprobs_close
(
outputs_0_lst
=
hf_outputs
,
outputs_0_lst
=
hf_outputs
,
...
...
tests/models/decoder_only/vision_language/test_qwen2_vl.py
View file @
59a0192f
...
@@ -5,7 +5,6 @@ import pytest
...
@@ -5,7 +5,6 @@ import pytest
import
torch
import
torch
from
PIL
import
Image
from
PIL
import
Image
from
vllm.entrypoints.llm
import
LLM
from
vllm.multimodal.image
import
rescale_image_size
from
vllm.multimodal.image
import
rescale_image_size
from
vllm.multimodal.video
import
rescale_video_size
,
sample_frames_from_video
from
vllm.multimodal.video
import
rescale_video_size
,
sample_frames_from_video
...
@@ -69,7 +68,7 @@ class Qwen2VLPromptVideoEmbeddingInput(TypedDict):
...
@@ -69,7 +68,7 @@ class Qwen2VLPromptVideoEmbeddingInput(TypedDict):
def
batch_make_image_embeddings
(
def
batch_make_image_embeddings
(
image_batches
:
List
[
Union
[
Image
.
Image
,
List
[
Image
.
Image
]]],
processor
,
image_batches
:
List
[
Union
[
Image
.
Image
,
List
[
Image
.
Image
]]],
processor
,
llm
:
LLM
)
->
List
[
Qwen2VLPromptImageEmbeddingInput
]:
llm
:
VllmRunner
)
->
List
[
Qwen2VLPromptImageEmbeddingInput
]:
"""batched image embeddings for Qwen2-VL
"""batched image embeddings for Qwen2-VL
This will infer all images' embeddings in a single batch,
This will infer all images' embeddings in a single batch,
...
@@ -106,16 +105,18 @@ def batch_make_image_embeddings(
...
@@ -106,16 +105,18 @@ def batch_make_image_embeddings(
image_grid_thw
=
preprocess_result
[
"image_grid_thw"
]
image_grid_thw
=
preprocess_result
[
"image_grid_thw"
]
# pixel values to embeddings & grid_thws
# pixel values to embeddings & grid_thws
with
torch
.
no_grad
(
):
def
get_image_embeds
(
model
):
visual
=
llm
.
llm_engine
.
model_executor
.
driver_worker
.
\
with
torch
.
no_grad
():
model_runner
.
model
.
visual
visual
=
model
.
visual
pixel_values_on_device
=
pixel_values
.
to
(
visual
.
device
,
pixel_values_on_device
=
pixel_values
.
to
(
visual
.
device
,
dtype
=
visual
.
dtype
)
dtype
=
visual
.
dtype
)
image_grid_thw_on_device
=
image_grid_thw
.
to
(
visual
.
device
,
image_grid_thw_on_device
=
image_grid_thw
.
to
(
visual
.
device
,
dtype
=
torch
.
int64
)
dtype
=
torch
.
int64
)
image_embeds
=
visual
(
pixel_values_on_device
,
return
visual
(
pixel_values_on_device
,
grid_thw
=
image_grid_thw_on_device
)
grid_thw
=
image_grid_thw_on_device
)
image_embeds
=
torch
.
concat
(
llm
.
apply_model
(
get_image_embeds
))
# split into original batches
# split into original batches
result
:
List
[
Qwen2VLPromptImageEmbeddingInput
]
=
[]
result
:
List
[
Qwen2VLPromptImageEmbeddingInput
]
=
[]
...
@@ -150,7 +151,7 @@ def batch_make_image_embeddings(
...
@@ -150,7 +151,7 @@ def batch_make_image_embeddings(
def
batch_make_video_embeddings
(
def
batch_make_video_embeddings
(
video_batches
:
PromptVideoInput
,
processor
,
video_batches
:
PromptVideoInput
,
processor
,
llm
:
LLM
)
->
List
[
Qwen2VLPromptVideoEmbeddingInput
]:
llm
:
VllmRunner
)
->
List
[
Qwen2VLPromptVideoEmbeddingInput
]:
"""batched video embeddings for Qwen2-VL
"""batched video embeddings for Qwen2-VL
A NDArray represents a single video's all frames.
A NDArray represents a single video's all frames.
...
@@ -187,16 +188,18 @@ def batch_make_video_embeddings(
...
@@ -187,16 +188,18 @@ def batch_make_video_embeddings(
video_grid_thw
=
preprocess_result
[
"video_grid_thw"
]
video_grid_thw
=
preprocess_result
[
"video_grid_thw"
]
# pixel values to embeddings & grid_thws
# pixel values to embeddings & grid_thws
with
torch
.
no_grad
():
def
get_image_embeds
(
model
):
visual
=
llm
.
llm_engine
.
model_executor
.
driver_worker
.
\
with
torch
.
no_grad
():
model_runner
.
model
.
visual
visual
=
model
.
visual
pixel_values_on_device
=
pixel_values
.
to
(
visual
.
device
,
dtype
=
visual
.
dtype
)
video_grid_thw_on_device
=
video_grid_thw
.
to
(
visual
.
device
,
dtype
=
torch
.
int64
)
return
visual
(
pixel_values_on_device
,
grid_thw
=
video_grid_thw_on_device
)
pixel_values_on_device
=
pixel_values
.
to
(
visual
.
device
,
video_embeds
=
torch
.
concat
(
llm
.
apply_model
(
get_image_embeds
))
dtype
=
visual
.
dtype
)
video_grid_thw_on_device
=
video_grid_thw
.
to
(
visual
.
device
,
dtype
=
torch
.
int64
)
video_embeds
=
visual
(
pixel_values_on_device
,
grid_thw
=
video_grid_thw_on_device
)
# split into original batches
# split into original batches
result
:
List
[
Qwen2VLPromptVideoEmbeddingInput
]
=
[]
result
:
List
[
Qwen2VLPromptVideoEmbeddingInput
]
=
[]
...
@@ -278,9 +281,9 @@ def run_embedding_input_test(
...
@@ -278,9 +281,9 @@ def run_embedding_input_test(
max_tokens
,
max_tokens
,
num_logprobs
=
num_logprobs
,
num_logprobs
=
num_logprobs
,
images
=
batch_make_image_embeddings
(
images
=
batch_make_image_embeddings
(
images
,
processor
,
vllm_model
.
model
)
if
images
else
None
,
images
,
processor
,
vllm_model
)
if
images
else
None
,
videos
=
batch_make_video_embeddings
(
videos
=
batch_make_video_embeddings
(
videos
,
processor
,
vllm_model
.
model
)
if
videos
else
None
)
videos
,
processor
,
vllm_model
)
if
videos
else
None
)
for
prompts
,
images
,
videos
in
inputs
for
prompts
,
images
,
videos
in
inputs
]
]
...
...
tests/models/embedding/language/test_cls_models.py
View file @
59a0192f
...
@@ -24,10 +24,13 @@ def test_classification_models(
...
@@ -24,10 +24,13 @@ def test_classification_models(
)
->
None
:
)
->
None
:
with
vllm_runner
(
model
,
dtype
=
dtype
)
as
vllm_model
:
with
vllm_runner
(
model
,
dtype
=
dtype
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
classify
(
example_prompts
)
vllm_outputs
=
vllm_model
.
classify
(
example_prompts
)
# This test is for verifying whether the model's extra_repr
# This test is for verifying whether the model's extra_repr
# can be printed correctly.
# can be printed correctly.
print
(
vllm_model
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
def
print_model
(
model
):
model_runner
.
model
)
print
(
model
)
vllm_model
.
apply_model
(
print_model
)
with
hf_runner
(
model
,
with
hf_runner
(
model
,
dtype
=
dtype
,
dtype
=
dtype
,
...
...
tests/models/embedding/language/test_embedding.py
View file @
59a0192f
...
@@ -62,10 +62,13 @@ def test_models(
...
@@ -62,10 +62,13 @@ def test_models(
max_model_len
=
None
,
max_model_len
=
None
,
**
vllm_extra_kwargs
)
as
vllm_model
:
**
vllm_extra_kwargs
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
encode
(
example_prompts
)
vllm_outputs
=
vllm_model
.
encode
(
example_prompts
)
# This test is for verifying whether the model's extra_repr
# This test is for verifying whether the model's extra_repr
# can be printed correctly.
# can be printed correctly.
print
(
vllm_model
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
def
print_model
(
model
):
model_runner
.
model
)
print
(
model
)
vllm_model
.
apply_model
(
print_model
)
check_embeddings_close
(
check_embeddings_close
(
embeddings_0_lst
=
hf_outputs
,
embeddings_0_lst
=
hf_outputs
,
...
...
tests/quantization/test_compressed_tensors.py
View file @
59a0192f
...
@@ -30,50 +30,55 @@ from vllm.platforms import current_platform
...
@@ -30,50 +30,55 @@ from vllm.platforms import current_platform
def
test_compressed_tensors_w8a8_static_setup
(
vllm_runner
,
model_args
):
def
test_compressed_tensors_w8a8_static_setup
(
vllm_runner
,
model_args
):
model_path
,
strategy
,
quant_type
,
shape_0
,
is_symmetric
=
model_args
model_path
,
strategy
,
quant_type
,
shape_0
,
is_symmetric
=
model_args
with
vllm_runner
(
model_path
,
enforce_eager
=
True
)
as
llm
:
with
vllm_runner
(
model_path
,
enforce_eager
=
True
)
as
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
layer
=
model
.
model
.
layers
[
0
]
def
check_model
(
model
):
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
o_proj
=
layer
.
self_attn
.
o_proj
qkv_proj
=
layer
.
self_attn
.
qkv_proj
gate_up_proj
=
layer
.
mlp
.
gate_up_proj
o_proj
=
layer
.
self_attn
.
o_proj
down_proj
=
layer
.
mlp
.
down_proj
gate_up_proj
=
layer
.
mlp
.
gate_up_proj
down_proj
=
layer
.
mlp
.
down_proj
# assert zp for symmetric and asymmetric cases
def
zp_valid
(
zp
:
Optional
[
torch
.
Tensor
]):
# assert zp for symmetric and asymmetric cases
if
is_symmetric
:
def
zp_valid
(
zp
:
Optional
[
torch
.
Tensor
]):
return
zp
is
None
if
is_symmetric
:
return
zp
is
None
return
zp
is
not
None
and
zp
.
dtype
is
torch
.
int32
return
zp
is
not
None
and
zp
.
dtype
is
torch
.
int32
assert
zp_valid
(
qkv_proj
.
input_zero_point
)
assert
zp_valid
(
o_proj
.
input_zero_point
)
assert
zp_valid
(
qkv_proj
.
input_zero_point
)
assert
zp_valid
(
gate_up_proj
.
input_zero_point
)
assert
zp_valid
(
o_proj
.
input_zero_point
)
assert
zp_valid
(
down_proj
.
input_zero_point
)
assert
zp_valid
(
gate_up_proj
.
input_zero_point
)
assert
zp_valid
(
down_proj
.
input_zero_point
)
assert
isinstance
(
qkv_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
o_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
qkv_proj
.
quant_method
,
assert
isinstance
(
gate_up_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
CompressedTensorsLinearMethod
)
assert
isinstance
(
o_proj
.
quant_method
,
assert
isinstance
(
down_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
CompressedTensorsLinearMethod
)
assert
isinstance
(
gate_up_proj
.
quant_method
,
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsW8A8Int8
)
CompressedTensorsLinearMethod
)
assert
isinstance
(
down_proj
.
quant_method
,
assert
qkv_proj
.
scheme
.
strategy
==
strategy
CompressedTensorsLinearMethod
)
assert
qkv_proj
.
scheme
.
is_static_input_scheme
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsW8A8Int8
)
expected_type
=
torch
.
int8
assert
qkv_proj
.
scheme
.
strategy
==
strategy
assert
qkv_proj
.
weight
.
dtype
is
expected_type
assert
qkv_proj
.
scheme
.
is_static_input_scheme
assert
o_proj
.
weight
.
dtype
is
expected_type
expected_type
=
torch
.
int8
assert
gate_up_proj
.
weight
.
dtype
is
expected_type
assert
qkv_proj
.
weight
.
dtype
is
expected_type
if
qkv_proj
.
scheme
.
strategy
==
"tensor"
:
assert
o_proj
.
weight
.
dtype
is
expected_type
# Make sure it is a channelwise buffer
assert
gate_up_proj
.
weight
.
dtype
is
expected_type
# After running process_weights_after_loading
assert
len
(
qkv_proj
.
weight_scale
.
shape
)
==
2
if
qkv_proj
.
scheme
.
strategy
==
"tensor"
:
assert
qkv_proj
.
weight_scale
.
shape
[
0
]
==
shape_0
# Make sure it is a channelwise buffer
assert
qkv_proj
.
weight_scale
.
shape
[
1
]
==
1
# After running process_weights_after_loading
assert
qkv_proj
.
weight_scale
.
dtype
is
torch
.
float32
assert
len
(
qkv_proj
.
weight_scale
.
shape
)
==
2
assert
qkv_proj
.
input_scale
.
dtype
is
torch
.
float32
assert
qkv_proj
.
weight_scale
.
shape
[
0
]
==
shape_0
assert
qkv_proj
.
weight_scale
.
shape
[
1
]
==
1
assert
qkv_proj
.
weight_scale
.
dtype
is
torch
.
float32
assert
qkv_proj
.
input_scale
.
dtype
is
torch
.
float32
llm
.
apply_model
(
check_model
)
output
=
llm
.
generate_greedy
([
"Hello my name is"
],
max_tokens
=
20
)
output
=
llm
.
generate_greedy
([
"Hello my name is"
],
max_tokens
=
20
)
assert
output
assert
output
...
@@ -129,16 +134,20 @@ def test_compressed_tensors_no_enforce_eager(vllm_runner):
...
@@ -129,16 +134,20 @@ def test_compressed_tensors_no_enforce_eager(vllm_runner):
def
test_compressed_tensors_w8a8_dynamic_per_token
(
vllm_runner
,
model_args
):
def
test_compressed_tensors_w8a8_dynamic_per_token
(
vllm_runner
,
model_args
):
model_path
,
strategy
=
model_args
model_path
,
strategy
=
model_args
with
vllm_runner
(
model_path
,
dtype
=
torch
.
float16
)
as
llm
:
with
vllm_runner
(
model_path
,
dtype
=
torch
.
float16
)
as
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
def
check_model
(
model
):
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
assert
isinstance
(
qkv_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsW8A8Int8
)
assert
not
qkv_proj
.
scheme
.
is_static_input_scheme
assert
qkv_proj
.
scheme
.
strategy
==
strategy
assert
qkv_proj
.
weight
.
dtype
is
torch
.
int8
assert
isinstance
(
qkv_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
llm
.
apply_model
(
check_model
)
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsW8A8Int8
)
assert
not
qkv_proj
.
scheme
.
is_static_input_scheme
assert
qkv_proj
.
scheme
.
strategy
==
strategy
assert
qkv_proj
.
weight
.
dtype
is
torch
.
int8
output
=
llm
.
generate_greedy
([
"Hello my name is"
],
max_tokens
=
20
)
output
=
llm
.
generate_greedy
([
"Hello my name is"
],
max_tokens
=
20
)
assert
output
assert
output
...
@@ -152,19 +161,24 @@ def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):
...
@@ -152,19 +161,24 @@ def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):
def
test_compressed_tensors_wNa16
(
vllm_runner
,
wNa16_args
):
def
test_compressed_tensors_wNa16
(
vllm_runner
,
wNa16_args
):
model
,
strategy
,
group
,
pack_factor
=
wNa16_args
model
,
strategy
,
group
,
pack_factor
=
wNa16_args
with
vllm_runner
(
model
)
as
llm
:
with
vllm_runner
(
model
)
as
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
def
check_model
(
model
):
assert
isinstance
(
qkv_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
layer
=
model
.
model
.
layers
[
0
]
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsWNA16
)
assert
qkv_proj
.
scheme
.
strategy
==
strategy
qkv_proj
=
layer
.
self_attn
.
qkv_proj
assert
qkv_proj
.
scheme
.
group_size
==
(
-
1
if
group
is
None
else
group
)
assert
isinstance
(
qkv_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsWNA16
)
assert
qkv_proj
.
weight_packed
.
dtype
is
torch
.
int32
assert
qkv_proj
.
scheme
.
strategy
==
strategy
assert
qkv_proj
.
weight_scale
.
dtype
is
torch
.
float16
assert
qkv_proj
.
scheme
.
group_size
==
(
-
1
assert
qkv_proj
.
scheme
.
pack_factor
==
pack_factor
if
group
is
None
else
group
)
assert
qkv_proj
.
weight_packed
.
dtype
is
torch
.
int32
assert
qkv_proj
.
weight_scale
.
dtype
is
torch
.
float16
assert
qkv_proj
.
scheme
.
pack_factor
==
pack_factor
llm
.
apply_model
(
check_model
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
assert
output
assert
output
...
@@ -173,14 +187,18 @@ def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
...
@@ -173,14 +187,18 @@ def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
def
test_compressed_tensors_w4a16_marlin24
(
vllm_runner
):
def
test_compressed_tensors_w4a16_marlin24
(
vllm_runner
):
model_path
=
"nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t"
model_path
=
"nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t"
with
vllm_runner
(
model_path
)
as
llm
:
with
vllm_runner
(
model_path
)
as
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
def
check_model
(
model
):
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
assert
isinstance
(
qkv_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
qkv_proj
.
quant_method
,
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsW4A16Sparse24
)
CompressedTensorsLinearMethod
)
assert
qkv_proj
.
weight_packed
.
dtype
is
torch
.
int32
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsW4A16Sparse24
)
assert
qkv_proj
.
weight_packed
.
dtype
is
torch
.
int32
llm
.
apply_model
(
check_model
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
assert
output
assert
output
...
@@ -189,23 +207,27 @@ def test_compressed_tensors_w4a16_marlin24(vllm_runner):
...
@@ -189,23 +207,27 @@ def test_compressed_tensors_w4a16_marlin24(vllm_runner):
def
test_compressed_tensors_fp8
(
vllm_runner
):
def
test_compressed_tensors_fp8
(
vllm_runner
):
model_path
=
"nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test"
model_path
=
"nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test"
with
vllm_runner
(
model_path
)
as
llm
:
with
vllm_runner
(
model_path
)
as
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
def
check_model
(
model
):
layer
=
model
.
model
.
layers
[
0
]
assert
isinstance
(
qkv_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
qkv_proj
=
layer
.
self_attn
.
qkv_proj
assert
isinstance
(
qkv_proj
.
scheme
,
(
CompressedTensorsW8A8Fp8
,
CompressedTensorsW8A16Fp8
))
assert
qkv_proj
.
input_scale
.
dtype
is
torch
.
float32
assert
isinstance
(
qkv_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
qkv_proj
.
scheme
,
(
CompressedTensorsW8A8Fp8
,
CompressedTensorsW8A16Fp8
))
if
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsW8A8Fp8
):
assert
qkv_proj
.
input_scale
.
dtype
is
torch
.
float32
assert
len
(
qkv_proj
.
input_scale
.
shape
)
==
0
assert
qkv_proj
.
weight
.
dtype
is
torch
.
float8_e4m3fn
if
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsW8A8Fp8
):
assert
qkv_proj
.
weight_scale
.
dtype
is
torch
.
float32
assert
len
(
qkv_proj
.
input_scale
.
shape
)
==
0
assert
len
(
qkv_proj
.
weight_scale
.
shape
)
==
0
assert
qkv_proj
.
weight
.
dtype
is
torch
.
float8_e4m3fn
assert
qkv_proj
.
weight_scale
.
dtype
is
torch
.
float32
assert
len
(
qkv_proj
.
weight_scale
.
shape
)
==
0
llm
.
apply_model
(
check_model
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
assert
output
assert
output
...
@@ -248,12 +270,15 @@ def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy):
...
@@ -248,12 +270,15 @@ def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy):
def
test_compressed_tensors_2of4_quant_fp8
(
vllm_runner
,
args_2of4
):
def
test_compressed_tensors_2of4_quant_fp8
(
vllm_runner
,
args_2of4
):
model
,
weight_strategy
,
input_strategy
=
args_2of4
model
,
weight_strategy
,
input_strategy
=
args_2of4
with
vllm_runner
(
model
)
as
llm
:
with
vllm_runner
(
model
)
as
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
def
check_model
(
model
):
assert
qkv_proj
.
scheme
.
weights_dtype
==
torch
.
float8_e4m3fn
layer
=
model
.
model
.
layers
[
0
]
_test_2of4_quant_models
(
qkv_proj
,
weight_strategy
,
input_strategy
)
qkv_proj
=
layer
.
self_attn
.
qkv_proj
assert
qkv_proj
.
scheme
.
weights_dtype
==
torch
.
float8_e4m3fn
_test_2of4_quant_models
(
qkv_proj
,
weight_strategy
,
input_strategy
)
llm
.
apply_model
(
check_model
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
print
(
output
)
print
(
output
)
...
@@ -273,12 +298,15 @@ def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4):
...
@@ -273,12 +298,15 @@ def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4):
def
test_compressed_tensors_2of4_quant_int8
(
vllm_runner
,
args_2of4
):
def
test_compressed_tensors_2of4_quant_int8
(
vllm_runner
,
args_2of4
):
model
,
weight_strategy
,
input_strategy
=
args_2of4
model
,
weight_strategy
,
input_strategy
=
args_2of4
with
vllm_runner
(
model
)
as
llm
:
with
vllm_runner
(
model
)
as
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
def
check_model
(
model
):
assert
qkv_proj
.
scheme
.
weights_dtype
==
torch
.
int8
layer
=
model
.
model
.
layers
[
0
]
_test_2of4_quant_models
(
qkv_proj
,
weight_strategy
,
input_strategy
)
qkv_proj
=
layer
.
self_attn
.
qkv_proj
assert
qkv_proj
.
scheme
.
weights_dtype
==
torch
.
int8
_test_2of4_quant_models
(
qkv_proj
,
weight_strategy
,
input_strategy
)
llm
.
apply_model
(
check_model
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
print
(
output
)
print
(
output
)
...
@@ -293,20 +321,24 @@ def test_compressed_tensors_2of4_quant_int8(vllm_runner, args_2of4):
...
@@ -293,20 +321,24 @@ def test_compressed_tensors_2of4_quant_int8(vllm_runner, args_2of4):
def
test_compressed_tensors_2of4_sparse
(
vllm_runner
,
args_2of4
):
def
test_compressed_tensors_2of4_sparse
(
vllm_runner
,
args_2of4
):
model
=
args_2of4
model
=
args_2of4
with
vllm_runner
(
model
)
as
llm
:
with
vllm_runner
(
model
)
as
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
layer
=
model
.
model
.
layers
[
0
]
def
check_model
(
model
):
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
assert
isinstance
(
qkv_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
qkv_proj
=
layer
.
self_attn
.
qkv_proj
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensors24
)
assert
isinstance
(
qkv_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
qkv_proj
.
scheme
.
weight_quant
is
None
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensors24
)
assert
qkv_proj
.
scheme
.
input_quant
is
None
assert
not
qkv_proj
.
scheme
.
quantized
assert
qkv_proj
.
scheme
.
weight_quant
is
None
assert
qkv_proj
.
quant_method
.
quantization_config
.
sparsity_scheme_map
assert
qkv_proj
.
scheme
.
input_quant
is
None
sparsity_map
=
qkv_proj
.
quant_method
.
quantization_config
.
sparsity_scheme_map
# noqa: E501
assert
not
qkv_proj
.
scheme
.
quantized
assert
sparsity_map
.
get
(
"Linear"
).
format
==
"dense"
assert
qkv_proj
.
quant_method
.
quantization_config
.
sparsity_scheme_map
assert
sparsity_map
.
get
(
"Linear"
).
sparsity_structure
==
"2:4"
sparsity_map
=
qkv_proj
.
quant_method
.
quantization_config
.
sparsity_scheme_map
# noqa: E501
assert
sparsity_map
.
get
(
"Linear"
).
format
==
"dense"
assert
sparsity_map
.
get
(
"Linear"
).
sparsity_structure
==
"2:4"
llm
.
apply_model
(
check_model
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
print
(
output
)
print
(
output
)
...
...
tests/quantization/test_fp8.py
View file @
59a0192f
...
@@ -49,13 +49,17 @@ KV_CACHE_MODELS = [
...
@@ -49,13 +49,17 @@ KV_CACHE_MODELS = [
def
test_kv_cache_model_load_and_run
(
vllm_runner
,
model_id
:
str
):
def
test_kv_cache_model_load_and_run
(
vllm_runner
,
model_id
:
str
):
with
vllm_runner
(
model_id
,
kv_cache_dtype
=
"fp8"
)
as
llm
:
with
vllm_runner
(
model_id
,
kv_cache_dtype
=
"fp8"
)
as
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
def
check_model
(
model
):
attn
=
model
.
model
.
layers
[
0
].
self_attn
.
attn
attn
=
model
.
model
.
layers
[
0
].
self_attn
.
attn
assert
isinstance
(
attn
.
quant_method
,
Fp8KVCacheMethod
)
# NOTE: it is valid for scales to be 1.0 (default value), but we know
assert
isinstance
(
attn
.
quant_method
,
Fp8KVCacheMethod
)
# these checkpoints have scales < 1.0
assert
0.0
<
attn
.
_k_scale
<
1.0
# NOTE: it is valid for scales to be 1.0 (default value), but
assert
0.0
<
attn
.
_v_scale
<
1.0
# we know these checkpoints have scales < 1.0
assert
0.0
<
attn
.
_k_scale
<
1.0
assert
0.0
<
attn
.
_v_scale
<
1.0
llm
.
apply_model
(
check_model
)
# note: this does not test accuracy, just that we can run through
# note: this does not test accuracy, just that we can run through
# see lm-eval tests for accuracy
# see lm-eval tests for accuracy
...
@@ -77,22 +81,24 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
...
@@ -77,22 +81,24 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
quantization
=
"fp8"
,
quantization
=
"fp8"
,
kv_cache_dtype
=
kv_cache_dtype
)
as
llm
:
kv_cache_dtype
=
kv_cache_dtype
)
as
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
def
check_model
(
model
):
fc1
=
model
.
model
.
decoder
.
layers
[
0
].
fc1
fc1
=
model
.
model
.
decoder
.
layers
[
0
].
fc1
assert
isinstance
(
fc1
.
quant_method
,
Fp8LinearMethod
)
assert
isinstance
(
fc1
.
quant_method
,
Fp8LinearMethod
)
if
kv_cache_dtype
==
"fp8"
:
if
kv_cache_dtype
==
"fp8"
:
attn
=
model
.
model
.
decoder
.
layers
[
0
].
self_attn
.
attn
attn
=
model
.
model
.
decoder
.
layers
[
0
].
self_attn
.
attn
assert
isinstance
(
attn
.
quant_method
,
Fp8KVCacheMethod
)
assert
isinstance
(
attn
.
quant_method
,
Fp8KVCacheMethod
)
assert
attn
.
_k_scale
==
1.0
assert
attn
.
_k_scale
==
1.0
assert
attn
.
_v_scale
==
1.0
assert
attn
.
_v_scale
==
1.0
if
current_platform
.
has_device_capability
(
89
)
and
not
force_marlin
:
if
current_platform
.
has_device_capability
(
89
)
and
not
force_marlin
:
# For GPUs with hardware support, we keep weights in fp8
# For GPUs with hardware support, we keep weights in fp8
assert
fc1
.
weight
.
dtype
==
torch
.
float8_e4m3fn
assert
fc1
.
weight
.
dtype
==
torch
.
float8_e4m3fn
else
:
else
:
# For GPUs without hardware support, we pack the fp8 weights
# For GPUs without hardware support, we pack the fp8 weights
# for weight-only quantization using Marlin kernels
# for weight-only quantization using Marlin kernels
assert
fc1
.
weight
.
dtype
==
torch
.
int32
assert
fc1
.
weight
.
dtype
==
torch
.
int32
llm
.
apply_model
(
check_model
)
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"fp8"
),
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"fp8"
),
...
...
tests/quantization/test_lm_head.py
View file @
59a0192f
...
@@ -28,20 +28,23 @@ def test_lm_head(
...
@@ -28,20 +28,23 @@ def test_lm_head(
model_lm_head_quant
:
Tuple
[
str
,
bool
],
model_lm_head_quant
:
Tuple
[
str
,
bool
],
)
->
None
:
)
->
None
:
model
,
lm_head_quantized
=
model_lm_head_quant
model
,
lm_head_quantized
=
model_lm_head_quant
vllm_model
=
vllm_runner
(
model
,
dtype
=
torch
.
float16
,
max_model_len
=
2048
)
with
vllm_runner
(
model
,
dtype
=
torch
.
float16
,
lm_head_layer
=
(
vllm_model
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
max_model_len
=
2048
)
as
vllm_model
:
model_runner
.
model
.
lm_head
)
def
check_model
(
model
):
if
lm_head_quantized
:
lm_head_layer
=
model
.
lm_head
assert
isinstance
(
lm_head_layer
.
linear_method
,
if
lm_head_quantized
:
(
GPTQLinearMethod
,
GPTQMarlinLinearMethod
,
MarlinLinearMethod
))
assert
isinstance
(
lm_head_layer
.
linear_method
,
else
:
(
GPTQLinearMethod
,
GPTQMarlinLinearMethod
,
assert
isinstance
(
lm_head_layer
.
linear_method
,
MarlinLinearMethod
))
UnquantizedEmbeddingMethod
)
else
:
assert
isinstance
(
lm_head_layer
.
linear_method
,
print
(
UnquantizedEmbeddingMethod
)
vllm_model
.
generate_greedy
(
prompts
=
[
"Hello my name is"
],
max_tokens
=
10
)[
0
][
1
])
vllm_model
.
apply_model
(
check_model
)
del
vllm_model
print
(
vllm_model
.
generate_greedy
(
prompts
=
[
"Hello my name is"
],
max_tokens
=
10
)[
0
][
1
])
tests/quantization/test_quark.py
View file @
59a0192f
...
@@ -12,19 +12,22 @@ from vllm.model_executor.layers.quantization.quark.quark import ( # noqa: E501
...
@@ -12,19 +12,22 @@ from vllm.model_executor.layers.quantization.quark.quark import ( # noqa: E501
def
test_quark_fp8
(
vllm_runner
):
def
test_quark_fp8
(
vllm_runner
):
model_path
=
"amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test"
model_path
=
"amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test"
with
vllm_runner
(
model_path
)
as
llm
:
with
vllm_runner
(
model_path
)
as
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
def
check_model
(
model
):
layer
=
model
.
model
.
layers
[
0
]
assert
isinstance
(
qkv_proj
.
quant_method
,
QuarkLinearMethod
)
qkv_proj
=
layer
.
self_attn
.
qkv_proj
assert
isinstance
(
qkv_proj
.
scheme
,
QuarkW8A8Fp8
)
if
isinstance
(
qkv_proj
.
scheme
,
QuarkW8A8Fp8
):
assert
isinstance
(
qkv_proj
.
quant_method
,
QuarkLinearMethod
)
assert
len
(
qkv_proj
.
input_scale
.
shape
)
==
0
assert
isinstance
(
qkv_proj
.
scheme
,
QuarkW8A8Fp8
)
assert
qkv_proj
.
weight
.
dtype
is
torch
.
float8_e4m3fn
#assert qkv_proj.weight.dtype is torch.float8_e4m3fnuz
if
isinstance
(
qkv_proj
.
scheme
,
QuarkW8A8Fp8
):
assert
len
(
qkv_proj
.
weight_scale
.
shape
)
==
0
assert
len
(
qkv_proj
.
input_scale
.
shape
)
==
0
assert
qkv_proj
.
weight
.
dtype
is
torch
.
float8_e4m3fn
#assert qkv_proj.weight.dtype is torch.float8_e4m3fnuz
assert
len
(
qkv_proj
.
weight_scale
.
shape
)
==
0
llm
.
apply_model
(
check_model
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
assert
output
assert
output
tests/tensorizer_loader/test_tensorizer.py
View file @
59a0192f
...
@@ -3,6 +3,7 @@ import json
...
@@ -3,6 +3,7 @@ import json
import
os
import
os
import
pathlib
import
pathlib
import
subprocess
import
subprocess
from
functools
import
partial
from
unittest.mock
import
MagicMock
,
patch
from
unittest.mock
import
MagicMock
,
patch
import
openai
import
openai
...
@@ -24,7 +25,6 @@ from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
...
@@ -24,7 +25,6 @@ from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
# yapf: enable
# yapf: enable
from
vllm.utils
import
PlaceholderModule
,
import_from_path
from
vllm.utils
import
PlaceholderModule
,
import_from_path
from
..conftest
import
VllmRunner
from
..utils
import
VLLM_PATH
,
RemoteOpenAIServer
from
..utils
import
VLLM_PATH
,
RemoteOpenAIServer
from
.conftest
import
retry_until_skip
from
.conftest
import
retry_until_skip
...
@@ -58,16 +58,6 @@ def is_curl_installed():
...
@@ -58,16 +58,6 @@ def is_curl_installed():
return
False
return
False
def
get_torch_model
(
vllm_runner
:
VllmRunner
):
return
vllm_runner
\
.
model
\
.
llm_engine
\
.
model_executor
\
.
driver_worker
\
.
model_runner
\
.
model
def
write_keyfile
(
keyfile_path
:
str
):
def
write_keyfile
(
keyfile_path
:
str
):
encryption_params
=
EncryptionParams
.
random
()
encryption_params
=
EncryptionParams
.
random
()
pathlib
.
Path
(
keyfile_path
).
parent
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
pathlib
.
Path
(
keyfile_path
).
parent
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
...
@@ -121,8 +111,10 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs(
...
@@ -121,8 +111,10 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs(
config_for_serializing
=
TensorizerConfig
(
tensorizer_uri
=
model_path
,
config_for_serializing
=
TensorizerConfig
(
tensorizer_uri
=
model_path
,
encryption_keyfile
=
key_path
)
encryption_keyfile
=
key_path
)
serialize_vllm_model
(
get_torch_model
(
vllm_model
),
config_for_serializing
)
vllm_model
.
apply_model
(
partial
(
serialize_vllm_model
,
tensorizer_config
=
config_for_serializing
))
config_for_deserializing
=
TensorizerConfig
(
tensorizer_uri
=
model_path
,
config_for_deserializing
=
TensorizerConfig
(
tensorizer_uri
=
model_path
,
encryption_keyfile
=
key_path
)
encryption_keyfile
=
key_path
)
...
@@ -175,8 +167,10 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
...
@@ -175,8 +167,10 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
with
vllm_runner
(
model_ref
,
)
as
vllm_model
:
with
vllm_runner
(
model_ref
,
)
as
vllm_model
:
model_path
=
tmp_path
/
(
model_ref
+
".tensors"
)
model_path
=
tmp_path
/
(
model_ref
+
".tensors"
)
serialize_vllm_model
(
get_torch_model
(
vllm_model
),
vllm_model
.
apply_model
(
TensorizerConfig
(
tensorizer_uri
=
model_path
))
partial
(
serialize_vllm_model
,
tensorizer_config
=
TensorizerConfig
(
tensorizer_uri
=
model_path
)))
with
vllm_runner
(
with
vllm_runner
(
model_ref
,
model_ref
,
...
@@ -215,8 +209,10 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
...
@@ -215,8 +209,10 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
with
vllm_runner
(
model_ref
,
)
as
vllm_model
:
with
vllm_runner
(
model_ref
,
)
as
vllm_model
:
model_path
=
tmp_path
/
(
model_ref
+
".tensors"
)
model_path
=
tmp_path
/
(
model_ref
+
".tensors"
)
serialize_vllm_model
(
get_torch_model
(
vllm_model
),
vllm_model
.
apply_model
(
TensorizerConfig
(
tensorizer_uri
=
model_path
))
partial
(
serialize_vllm_model
,
tensorizer_config
=
TensorizerConfig
(
tensorizer_uri
=
model_path
)))
model_loader_extra_config
=
{
model_loader_extra_config
=
{
"tensorizer_uri"
:
str
(
model_path
),
"tensorizer_uri"
:
str
(
model_path
),
...
@@ -337,7 +333,9 @@ def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
...
@@ -337,7 +333,9 @@ def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
with
vllm_runner
(
model_ref
)
as
vllm_model
:
with
vllm_runner
(
model_ref
)
as
vllm_model
:
outputs
=
vllm_model
.
generate
(
prompts
,
sampling_params
)
outputs
=
vllm_model
.
generate
(
prompts
,
sampling_params
)
serialize_vllm_model
(
get_torch_model
(
vllm_model
),
config
)
vllm_model
.
apply_model
(
partial
(
serialize_vllm_model
,
tensorizer_config
=
config
))
assert
is_vllm_tensorized
(
config
)
assert
is_vllm_tensorized
(
config
)
...
...
vllm/engine/llm_engine.py
View file @
59a0192f
...
@@ -5,10 +5,10 @@ from collections import deque
...
@@ -5,10 +5,10 @@ from collections import deque
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
functools
import
partial
from
functools
import
partial
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
ClassVar
,
Deque
,
Dict
,
from
typing
import
(
TYPE_CHECKING
,
Callable
,
ClassVar
,
Deque
,
Dict
,
Iterable
,
Iterable
,
List
,
Mapping
,
NamedTuple
,
Optional
)
List
,
Mapping
,
NamedTuple
,
Optional
)
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Set
,
Tuple
,
Type
,
Union
,
cast
,
overload
from
typing
import
Set
,
Type
,
Union
,
cast
,
overload
import
torch
import
torch
from
typing_extensions
import
TypeVar
,
deprecated
from
typing_extensions
import
TypeVar
,
deprecated
...
@@ -1818,17 +1818,6 @@ class LLMEngine:
...
@@ -1818,17 +1818,6 @@ class LLMEngine:
def
stop_profile
(
self
)
->
None
:
def
stop_profile
(
self
)
->
None
:
self
.
model_executor
.
stop_profile
()
self
.
model_executor
.
stop_profile
()
def
collective_rpc
(
self
,
method
:
Union
[
str
,
Callable
],
timeout
:
Optional
[
float
]
=
None
,
args
:
Tuple
=
(),
kwargs
:
Optional
[
Dict
]
=
None
)
->
List
[
Any
]:
"""
See LLM.collective_rpc for more details.
"""
return
self
.
model_executor
.
collective_rpc
(
method
,
timeout
,
args
,
kwargs
)
def
check_health
(
self
)
->
None
:
def
check_health
(
self
)
->
None
:
if
self
.
tokenizer
:
if
self
.
tokenizer
:
self
.
tokenizer
.
check_health
()
self
.
tokenizer
.
check_health
()
...
...
vllm/entrypoints/llm.py
View file @
59a0192f
...
@@ -5,8 +5,9 @@ from typing import (Any, Callable, ClassVar, Dict, List, Optional, Sequence,
...
@@ -5,8 +5,9 @@ from typing import (Any, Callable, ClassVar, Dict, List, Optional, Sequence,
Tuple
,
Type
,
Union
,
cast
,
overload
)
Tuple
,
Type
,
Union
,
cast
,
overload
)
import
cloudpickle
import
cloudpickle
import
torch.nn
as
nn
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
typing_extensions
import
deprecated
from
typing_extensions
import
TypeVar
,
deprecated
from
vllm
import
envs
from
vllm
import
envs
from
vllm.beam_search
import
(
BeamSearchInstance
,
BeamSearchOutput
,
from
vllm.beam_search
import
(
BeamSearchInstance
,
BeamSearchOutput
,
...
@@ -42,6 +43,8 @@ from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of
...
@@ -42,6 +43,8 @@ from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_R
=
TypeVar
(
"_R"
,
default
=
Any
)
class
LLM
:
class
LLM
:
"""An LLM for generating texts from given prompts and sampling parameters.
"""An LLM for generating texts from given prompts and sampling parameters.
...
@@ -464,25 +467,42 @@ class LLM:
...
@@ -464,25 +467,42 @@ class LLM:
return
self
.
engine_class
.
validate_outputs
(
outputs
,
RequestOutput
)
return
self
.
engine_class
.
validate_outputs
(
outputs
,
RequestOutput
)
def
collective_rpc
(
self
,
def
collective_rpc
(
self
,
method
:
Union
[
str
,
Callable
],
method
:
Union
[
str
,
Callable
[...,
_R
]
],
timeout
:
Optional
[
float
]
=
None
,
timeout
:
Optional
[
float
]
=
None
,
args
:
Tuple
=
(),
args
:
Tuple
=
(),
kwargs
:
Optional
[
Dict
]
=
None
)
->
List
[
Any
]:
kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
)
->
List
[
_R
]:
"""
Execute an RPC call on all workers.
Args:
method: Name of the worker method to execute, or a callable that
is serialized and sent to all workers to execute.
If the method is a callable, it should accept an additional
`self` argument, in addition to the arguments passed in `args`
and `kwargs`. The `self` argument will be the worker object.
timeout: Maximum time in seconds to wait for execution. Raises a
:exc:`TimeoutError` on timeout. `None` means wait indefinitely.
args: Positional arguments to pass to the worker method.
kwargs: Keyword arguments to pass to the worker method.
Returns:
A list containing the results from each worker.
Note:
It is recommended to use this API to only pass control messages,
and set up data-plane communication to pass data.
"""
executor
=
self
.
llm_engine
.
model_executor
return
executor
.
collective_rpc
(
method
,
timeout
,
args
,
kwargs
)
def
apply_model
(
self
,
func
:
Callable
[[
nn
.
Module
],
_R
])
->
list
[
_R
]:
"""
"""
Run a method on all workers, with homogeneous arguments.
Run a function directly on the model inside each worker,
The main extension point for the LLM entrypoint.
returning the result for each of them.
Users can provide custom worker class through `worker_cls`
argument, and implement new methods in the worker class.
Then, users can call the new methods through this API.
It is recommended to use this API to only pass control messages,
and set up data-plane communication to pass data.
The method can also be a callable, which will be serialized
and sent to all workers to execute.
If the method is a callable, it should accept an additional
`self` argument, in addition to the arguments passed in `args`
and `kwargs`. The `self` argument will be the worker object.
"""
"""
return
self
.
llm_engine
.
collective_rpc
(
method
,
timeout
,
args
,
kwargs
)
executor
=
self
.
llm_engine
.
model_executor
return
executor
.
apply_model
(
func
)
def
beam_search
(
def
beam_search
(
self
,
self
,
...
...
vllm/executor/executor_base.py
View file @
59a0192f
...
@@ -3,6 +3,9 @@ from abc import ABC, abstractmethod
...
@@ -3,6 +3,9 @@ from abc import ABC, abstractmethod
from
typing
import
(
Any
,
Awaitable
,
Callable
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
from
typing
import
(
Any
,
Awaitable
,
Callable
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Union
)
Union
)
import
torch.nn
as
nn
from
typing_extensions
import
TypeVar
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
...
@@ -11,9 +14,12 @@ from vllm.platforms import current_platform
...
@@ -11,9 +14,12 @@ from vllm.platforms import current_platform
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sequence
import
ExecuteModelRequest
,
PoolerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
PoolerOutput
from
vllm.utils
import
make_async
from
vllm.utils
import
make_async
from
vllm.worker.worker_base
import
WorkerBase
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_R
=
TypeVar
(
"_R"
,
default
=
Any
)
class
ExecutorBase
(
ABC
):
class
ExecutorBase
(
ABC
):
"""Base class for all executors.
"""Base class for all executors.
...
@@ -44,22 +50,37 @@ class ExecutorBase(ABC):
...
@@ -44,22 +50,37 @@ class ExecutorBase(ABC):
@
abstractmethod
@
abstractmethod
def
_init_executor
(
self
)
->
None
:
def
_init_executor
(
self
)
->
None
:
pass
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
collective_rpc
(
self
,
def
collective_rpc
(
self
,
method
:
Union
[
str
,
Callable
],
method
:
Union
[
str
,
Callable
[...,
_R
]
],
timeout
:
Optional
[
float
]
=
None
,
timeout
:
Optional
[
float
]
=
None
,
args
:
Tuple
=
(),
args
:
Tuple
=
(),
kwargs
:
Optional
[
Dict
]
=
None
)
->
List
[
Any
]:
kwargs
:
Optional
[
Dict
[
str
,
Any
]
]
=
None
)
->
List
[
_R
]:
"""
"""
The main interface of the executor to run a method on all workers,
Execute an RPC call on all workers.
with homogeneous arguments.
If the args are heterogeneous, then we can pack them into a list,
Args:
and unpack them in the method of every worker, because every worker
method: Name of the worker method to execute, or a callable that
knows their own rank.
is serialized and sent to all workers to execute.
If the method is a callable, it should accept an additional
`self` argument, in addition to the arguments passed in `args`
and `kwargs`. The `self` argument will be the worker object.
timeout: Maximum time in seconds to wait for execution. Raises a
:exc:`TimeoutError` on timeout. `None` means wait indefinitely.
args: Positional arguments to pass to the worker method.
kwargs: Keyword arguments to pass to the worker method.
Returns:
A list containing the results from each worker.
Note:
It is recommended to use this API to only pass control messages,
and set up data-plane communication to pass data.
"""
"""
pass
raise
NotImplementedError
def
determine_num_available_blocks
(
self
)
->
Tuple
[
int
,
int
]:
def
determine_num_available_blocks
(
self
)
->
Tuple
[
int
,
int
]:
"""Determine the number of available blocks for the GPU KV cache and
"""Determine the number of available blocks for the GPU KV cache and
...
@@ -97,6 +118,17 @@ class ExecutorBase(ABC):
...
@@ -97,6 +118,17 @@ class ExecutorBase(ABC):
self
.
collective_rpc
(
"initialize_cache"
,
self
.
collective_rpc
(
"initialize_cache"
,
args
=
(
num_gpu_blocks
,
num_cpu_blocks
))
args
=
(
num_gpu_blocks
,
num_cpu_blocks
))
def
apply_model
(
self
,
func
:
Callable
[[
nn
.
Module
],
_R
])
->
list
[
_R
]:
"""
Run a function directly on the model inside each worker,
returning the result for each of them.
"""
def
rpc_func
(
worker
:
WorkerBase
)
->
_R
:
return
func
(
worker
.
get_model
())
return
self
.
collective_rpc
(
rpc_func
)
def
execute_model
(
def
execute_model
(
self
,
execute_model_req
:
ExecuteModelRequest
self
,
execute_model_req
:
ExecuteModelRequest
)
->
Optional
[
List
[
Union
[
SamplerOutput
,
PoolerOutput
]]]:
)
->
Optional
[
List
[
Union
[
SamplerOutput
,
PoolerOutput
]]]:
...
...
vllm/executor/mp_distributed_executor.py
View file @
59a0192f
...
@@ -148,7 +148,7 @@ class MultiprocessingDistributedExecutor(DistributedExecutorBase):
...
@@ -148,7 +148,7 @@ class MultiprocessingDistributedExecutor(DistributedExecutorBase):
async_run_tensor_parallel_workers_only
:
bool
=
False
,
async_run_tensor_parallel_workers_only
:
bool
=
False
,
max_concurrent_workers
:
Optional
[
int
]
=
None
,
max_concurrent_workers
:
Optional
[
int
]
=
None
,
**
kwargs
,
**
kwargs
,
)
->
Any
:
)
->
List
[
Any
]
:
"""Runs the given method on all workers.
"""Runs the given method on all workers.
Args:
Args:
...
...
vllm/model_executor/model_loader/tensorizer.py
View file @
59a0192f
...
@@ -459,16 +459,7 @@ def tensorize_vllm_model(engine_args: EngineArgs,
...
@@ -459,16 +459,7 @@ def tensorize_vllm_model(engine_args: EngineArgs,
stream
.
write
(
encryption_params
.
key
)
stream
.
write
(
encryption_params
.
key
)
engine
=
LLMEngine
.
from_engine_args
(
engine_args
)
engine
=
LLMEngine
.
from_engine_args
(
engine_args
)
if
tensorizer_config
.
_is_sharded
:
engine
.
model_executor
.
collective_rpc
(
# if the engine is a distributed engine (for tensor parallel) then each
"save_tensorized_model"
,
# worker shard needs to serialize its part of the model.
kwargs
=
dict
(
tensorizer_config
=
tensorizer_config
),
engine
.
model_executor
.
_run_workers
(
)
"save_tensorized_model"
,
tensorizer_config
=
tensorizer_config
,
)
else
:
# with a single worker, we can get to the underlying model directly
serialize_vllm_model
(
engine
.
model_executor
.
driver_worker
.
model_runner
.
model
,
tensorizer_config
,
)
vllm/spec_decode/ngram_worker.py
View file @
59a0192f
...
@@ -2,6 +2,7 @@ import weakref
...
@@ -2,6 +2,7 @@ import weakref
from
typing
import
List
,
Optional
,
Set
,
Tuple
from
typing
import
List
,
Optional
,
Set
,
Tuple
import
torch
import
torch
import
torch.nn
as
nn
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.sequence
import
ExecuteModelRequest
...
@@ -10,6 +11,10 @@ from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
...
@@ -10,6 +11,10 @@ from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
class
_DummyModel
(
nn
.
Module
):
pass
class
NGramWorker
(
NonLLMProposerWorkerBase
):
class
NGramWorker
(
NonLLMProposerWorkerBase
):
"""NGramWorker provides a light drafter without need for model.
"""NGramWorker provides a light drafter without need for model.
...
@@ -36,7 +41,6 @@ class NGramWorker(NonLLMProposerWorkerBase):
...
@@ -36,7 +41,6 @@ class NGramWorker(NonLLMProposerWorkerBase):
def
init_device
(
self
):
def
init_device
(
self
):
self
.
device
=
torch
.
device
(
f
"
{
self
.
device_type
}
:
{
self
.
local_rank
}
"
)
self
.
device
=
torch
.
device
(
f
"
{
self
.
device_type
}
:
{
self
.
local_rank
}
"
)
self
.
load_model
=
lambda
*
args
,
**
kwargs
:
None
# Current NGramWorker only supports Top1Proposer
# Current NGramWorker only supports Top1Proposer
self
.
_proposer
=
Top1Proposer
(
self
.
_proposer
=
Top1Proposer
(
...
@@ -45,6 +49,12 @@ class NGramWorker(NonLLMProposerWorkerBase):
...
@@ -45,6 +49,12 @@ class NGramWorker(NonLLMProposerWorkerBase):
vocab_size
=
self
.
vocab_size
,
vocab_size
=
self
.
vocab_size
,
)
)
def
load_model
(
self
)
->
None
:
pass
# Dummy
def
get_model
(
self
)
->
nn
.
Module
:
return
_DummyModel
()
def
sampler_output
(
def
sampler_output
(
self
,
self
,
execute_model_req
:
ExecuteModelRequest
,
execute_model_req
:
ExecuteModelRequest
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment