Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
008cf886
Unverified
Commit
008cf886
authored
Sep 04, 2024
by
Harsha vardhan manoj Bikki
Committed by
GitHub
Sep 04, 2024
Browse files
[Neuron] Adding support for adding/ overriding neuron configuration a… (#8062)
Co-authored-by:
Harsha Bikki
<
harbikh@amazon.com
>
parent
77d9e514
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
243 additions
and
42 deletions
+243
-42
examples/offline_inference_neuron_int8_quantization.py
examples/offline_inference_neuron_int8_quantization.py
+50
-0
vllm/config.py
vllm/config.py
+41
-28
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+14
-3
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+2
-0
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+3
-0
vllm/model_executor/layers/quantization/neuron_quant.py
vllm/model_executor/layers/quantization/neuron_quant.py
+67
-0
vllm/model_executor/model_loader/neuron.py
vllm/model_executor/model_loader/neuron.py
+57
-8
vllm/worker/neuron_model_runner.py
vllm/worker/neuron_model_runner.py
+9
-3
No files found.
examples/offline_inference_neuron_int8_quantization.py
0 → 100644
View file @
008cf886
import
os
from
vllm
import
LLM
,
SamplingParams
# Neuron runtime settings are read from the environment, so they are set
# before the engine is constructed:
#   * NEURON_CONTEXT_LENGTH_BUCKETS - XLA HLO graphs are created for each of
#     the listed context length buckets.
#   * NEURON_TOKEN_GEN_BUCKETS - XLA HLO graphs are created for each of the
#     listed token generation buckets.
#   * NEURON_QUANT_DTYPE - dtype the neuron model weights are quantized to;
#     "s8" (int8) is also the default.
os.environ.update({
    'NEURON_CONTEXT_LENGTH_BUCKETS': "128,512,1024,2048",
    'NEURON_TOKEN_GEN_BUCKETS': "128,512,1024,2048",
    'NEURON_QUANT_DTYPE': "s8",
})

# Prompts to complete.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# Sampling settings shared by every prompt.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Build the engine.
llm = LLM(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    max_num_seqs=8,
    # When targeting a neuron device, max_model_len and block_size must both
    # equal the max sequence length. This is a known limitation of continuous
    # batching support in transformers-neuronx.
    # TODO(liangfu): Support paged-attention in transformers-neuronx.
    max_model_len=2048,
    block_size=2048,
    # With the AWS Neuron SDK installed the device can be auto-detected; it
    # may be left unspecified or, as here, assigned explicitly.
    device="neuron",
    quantization="neuron_quant",
    override_neuron_config={
        "cast_logits_dtype": "bfloat16",
    },
    tensor_parallel_size=2,
)

# generate() returns a list of RequestOutput objects holding the prompt, the
# generated text, and other information.
outputs = llm.generate(prompts, sampling_params)

# Show each prompt next to its first completion.
for output in outputs:
    prompt, generated_text = output.prompt, output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
vllm/config.py
View file @
008cf886
import
enum
import
json
from
dataclasses
import
dataclass
,
field
,
fields
from
typing
import
(
TYPE_CHECKING
,
ClassVar
,
List
,
Mapping
,
Optional
,
Tuple
,
Type
,
Union
)
from
typing
import
(
TYPE_CHECKING
,
Any
,
ClassVar
,
Dict
,
List
,
Mapping
,
Optional
,
Tuple
,
Type
,
Union
)
import
torch
from
transformers
import
PretrainedConfig
...
...
@@ -115,35 +115,39 @@ class ModelConfig:
the model name will be the same as `model`.
limit_mm_per_prompt: Maximum number of data instances per modality
per prompt. Only applicable for multimodal models.
override_neuron_config: Initialize non default neuron config or
override default neuron config that are specific to Neuron devices,
this argument will be used to configure the neuron config that
can not be gathered from the vllm arguments.
"""
def
__init__
(
self
,
model
:
str
,
tokenizer
:
str
,
tokenizer_mode
:
str
,
trust_remote_code
:
bool
,
dtype
:
Union
[
str
,
torch
.
dtype
],
seed
:
int
,
revision
:
Optional
[
str
]
=
None
,
code_revision
:
Optional
[
str
]
=
None
,
rope_scaling
:
Optional
[
dict
]
=
None
,
rope_theta
:
Optional
[
float
]
=
None
,
tokenizer_revision
:
Optional
[
str
]
=
None
,
max_model_len
:
Optional
[
int
]
=
None
,
spec_target_max_model_len
:
Optional
[
int
]
=
None
,
quantization
:
Optional
[
str
]
=
None
,
quantization_param_path
:
Optional
[
str
]
=
None
,
enforce_eager
:
Optional
[
bool
]
=
None
,
max_context_len_to_capture
:
Optional
[
int
]
=
None
,
max_seq_len_to_capture
:
Optional
[
int
]
=
None
,
max_logprobs
:
int
=
20
,
disable_sliding_window
:
bool
=
False
,
skip_tokenizer_init
:
bool
=
False
,
served_model_name
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
limit_mm_per_prompt
:
Optional
[
Mapping
[
str
,
int
]]
=
None
,
use_async_output_proc
:
bool
=
True
,
)
->
None
:
self
,
model
:
str
,
tokenizer
:
str
,
tokenizer_mode
:
str
,
trust_remote_code
:
bool
,
dtype
:
Union
[
str
,
torch
.
dtype
],
seed
:
int
,
revision
:
Optional
[
str
]
=
None
,
code_revision
:
Optional
[
str
]
=
None
,
rope_scaling
:
Optional
[
dict
]
=
None
,
rope_theta
:
Optional
[
float
]
=
None
,
tokenizer_revision
:
Optional
[
str
]
=
None
,
max_model_len
:
Optional
[
int
]
=
None
,
spec_target_max_model_len
:
Optional
[
int
]
=
None
,
quantization
:
Optional
[
str
]
=
None
,
quantization_param_path
:
Optional
[
str
]
=
None
,
enforce_eager
:
Optional
[
bool
]
=
None
,
max_context_len_to_capture
:
Optional
[
int
]
=
None
,
max_seq_len_to_capture
:
Optional
[
int
]
=
None
,
max_logprobs
:
int
=
20
,
disable_sliding_window
:
bool
=
False
,
skip_tokenizer_init
:
bool
=
False
,
served_model_name
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
limit_mm_per_prompt
:
Optional
[
Mapping
[
str
,
int
]]
=
None
,
use_async_output_proc
:
bool
=
True
,
override_neuron_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
)
->
None
:
self
.
model
=
model
self
.
tokenizer
=
tokenizer
self
.
tokenizer_mode
=
tokenizer_mode
...
...
@@ -227,6 +231,9 @@ class ModelConfig:
limit_mm_per_prompt
)
if
not
self
.
skip_tokenizer_init
:
self
.
_verify_tokenizer_mode
()
self
.
override_neuron_config
=
override_neuron_config
if
is_neuron
(
)
else
None
self
.
_verify_embedding_mode
()
self
.
_verify_quantization
()
self
.
_verify_cuda_graph
()
...
...
@@ -275,6 +282,7 @@ class ModelConfig:
"experts_int8"
]
tpu_supported_quantization
=
[
"tpu_int8"
]
neuron_supported_quantization
=
[
"neuron_quant"
]
if
self
.
quantization
is
not
None
:
self
.
quantization
=
self
.
quantization
.
lower
()
...
...
@@ -329,6 +337,11 @@ class ModelConfig:
"Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
" is not set, enabling VLLM_USE_TRITON_AWQ."
)
envs
.
VLLM_USE_TRITON_AWQ
=
True
if
is_neuron
(
)
and
self
.
quantization
not
in
neuron_supported_quantization
:
raise
ValueError
(
f
"
{
self
.
quantization
}
quantization is currently not "
f
"supported in Neuron Backend."
)
def
_verify_cuda_graph
(
self
)
->
None
:
if
self
.
max_seq_len_to_capture
is
None
:
...
...
vllm/engine/arg_utils.py
View file @
008cf886
...
...
@@ -2,8 +2,8 @@ import argparse
import
dataclasses
import
json
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Dict
,
List
,
Mapping
,
Optional
,
Tuple
,
Type
,
Union
)
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Mapping
,
Optional
,
Tuple
,
Type
,
Union
)
import
torch
...
...
@@ -149,6 +149,7 @@ class EngineArgs:
otlp_traces_endpoint
:
Optional
[
str
]
=
None
collect_detailed_traces
:
Optional
[
str
]
=
None
disable_async_output_proc
:
bool
=
False
override_neuron_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
def
__post_init__
(
self
):
if
self
.
tokenizer
is
None
:
...
...
@@ -742,6 +743,16 @@ class EngineArgs:
default
=
EngineArgs
.
disable_async_output_proc
,
help
=
"Disable async output processing. This may result in "
"lower performance."
)
parser
.
add_argument
(
'--override-neuron-config'
,
type
=
lambda
configs
:
{
str
(
key
):
value
for
key
,
value
in
(
config
.
split
(
':'
)
for
config
in
configs
.
split
(
','
))
},
default
=
None
,
help
=
"override or set neuron device configuration."
)
return
parser
@
classmethod
...
...
@@ -802,7 +813,7 @@ class EngineArgs:
served_model_name
=
self
.
served_model_name
,
limit_mm_per_prompt
=
self
.
limit_mm_per_prompt
,
use_async_output_proc
=
not
self
.
disable_async_output_proc
,
)
override_neuron_config
=
self
.
override_neuron_config
)
cache_config
=
CacheConfig
(
block_size
=
self
.
block_size
if
self
.
device
!=
"neuron"
else
self
.
max_model_len
,
# neuron needs block_size = max_model_len
...
...
vllm/engine/llm_engine.py
View file @
008cf886
...
...
@@ -214,6 +214,7 @@ class LLMEngine:
"Initializing an LLM engine (v%s) with config: "
"model=%r, speculative_config=%r, tokenizer=%r, "
"skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
"override_neuron_config=%s, "
"rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
"trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
"download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
...
...
@@ -232,6 +233,7 @@ class LLMEngine:
model_config
.
skip_tokenizer_init
,
model_config
.
tokenizer_mode
,
model_config
.
revision
,
model_config
.
override_neuron_config
,
model_config
.
rope_scaling
,
model_config
.
rope_theta
,
model_config
.
tokenizer_revision
,
...
...
vllm/model_executor/layers/quantization/__init__.py
View file @
008cf886
...
...
@@ -22,6 +22,8 @@ from vllm.model_executor.layers.quantization.gptq_marlin import (
from
vllm.model_executor.layers.quantization.gptq_marlin_24
import
(
GPTQMarlin24Config
)
from
vllm.model_executor.layers.quantization.marlin
import
MarlinConfig
from
vllm.model_executor.layers.quantization.neuron_quant
import
(
NeuronQuantConfig
)
from
vllm.model_executor.layers.quantization.qqq
import
QQQConfig
from
vllm.model_executor.layers.quantization.squeezellm
import
SqueezeLLMConfig
from
vllm.model_executor.layers.quantization.tpu_int8
import
Int8TpuConfig
...
...
@@ -46,6 +48,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"bitsandbytes"
:
BitsAndBytesConfig
,
"qqq"
:
QQQConfig
,
"experts_int8"
:
ExpertsInt8Config
,
"neuron_quant"
:
NeuronQuantConfig
,
}
...
...
vllm/model_executor/layers/quantization/neuron_quant.py
0 → 100644
View file @
008cf886
import
os
from
importlib.util
import
find_spec
from
typing
import
Any
,
Dict
,
List
,
Optional
from
torch.nn
import
Module
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
# Quantized weight dtypes accepted by NeuronQuantConfig; the value of the
# NEURON_QUANT_DTYPE environment variable is validated against this list.
SUPPORTED_QUANT_DTYPE_LIST
=
[
's8'
,
'f8e4m3fn'
]
class NeuronQuantConfig(QuantizationConfig):
    """Weight quantization config for the Neuron backend.

    The quantized weight dtype is read from the ``NEURON_QUANT_DTYPE``
    environment variable (default ``"s8"``) and must be one of
    ``SUPPORTED_QUANT_DTYPE_LIST``. The resulting settings are handed to
    ``transformers_neuronx.config.QuantizationConfig``.
    """

    def __init__(
        self,
        dequant_dtype: str = "f16",
        quantize_method: str = "vector_dynamic",
    ) -> None:
        """Store the quantization settings.

        Args:
            dequant_dtype: Dtype name forwarded as ``dequant_dtype`` to the
                transformers_neuronx quantization config.
            quantize_method: Method name forwarded as ``quantize_method``.

        Raises:
            ValueError: If ``NEURON_QUANT_DTYPE`` is set to a value that is
                not in ``SUPPORTED_QUANT_DTYPE_LIST``.
        """
        self.quant_dtype = os.getenv("NEURON_QUANT_DTYPE", "s8")
        if self.quant_dtype not in SUPPORTED_QUANT_DTYPE_LIST:
            # Adjacent f-strings are concatenated verbatim, so each fragment
            # carries its own separating space.
            raise ValueError(
                f"Neuron quantization datatype {self.quant_dtype} is not "
                f"valid, the quantization datatype should match one of the "
                f"below types {SUPPORTED_QUANT_DTYPE_LIST}")
        self.dequant_dtype = dequant_dtype
        self.quantize_method = quantize_method

    def get_name(self) -> str:
        """Return the name this quantization method is registered under."""
        return "neuron_quant"

    def get_supported_act_dtypes(self) -> List[str]:
        """Return the list of supported quantization dtype names."""
        return SUPPORTED_QUANT_DTYPE_LIST

    @classmethod
    def get_min_capability(cls) -> int:
        """Not applicable to Neuron; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "This function should not be called with Neuron Backend")

    @staticmethod
    def get_config_filenames() -> List[str]:
        """Return no filenames: no config file is searched for."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "NeuronQuantConfig":
        """Build a config from a dict.

        Both ``"quantize_method"`` and ``"dequant_dtype"`` are looked up
        through ``get_from_keys``.
        """
        quantize_method = cls.get_from_keys(config, ["quantize_method"])
        dequant_dtype = cls.get_from_keys(config, ["dequant_dtype"])
        return cls(dequant_dtype=dequant_dtype,
                   quantize_method=quantize_method)

    def get_quant_method(self, layer: Module, prefix: str) -> Optional[Any]:
        """Return the transformers_neuronx quantization config.

        ``layer`` and ``prefix`` are not used here.

        Raises:
            NotImplementedError: If ``transformers_neuronx`` cannot be found.
        """
        if find_spec("transformers_neuronx") is not None:
            return self.get_quantization_config()
        else:
            raise NotImplementedError(
                "Neuron Quantization is only supported through"
                " transformers_neuronx.")

    def get_scaled_act_names(self) -> List[str]:
        """Return an empty list of scaled activation names."""
        return []

    def get_quantization_config(self):
        """Create the ``transformers_neuronx`` ``QuantizationConfig``."""
        # Imported lazily so this module can be imported without
        # transformers_neuronx installed.
        from transformers_neuronx.config import QuantizationConfig
        return QuantizationConfig(quant_dtype=self.quant_dtype,
                                  dequant_dtype=self.dequant_dtype,
                                  quantize_method=self.quantize_method)
vllm/model_executor/model_loader/neuron.py
View file @
008cf886
...
...
@@ -10,6 +10,7 @@ from transformers import PretrainedConfig
from
vllm.config
import
ModelConfig
,
ParallelConfig
,
SchedulerConfig
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
get_quantization_config
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
...
...
@@ -81,8 +82,7 @@ class NeuronCasualLM(nn.Module):
neuronx_model_cls
=
getattr
(
neuronx_module
,
neuronx_model_cls_name
)
split_model_dir
=
f
"
{
model_name_or_path
}
-split"
if
os
.
path
.
isdir
(
os
.
path
.
join
(
model_name_or_path
,
"pytorch_model.bin"
)):
if
_is_pretrained_neuron_checkpoint
(
model_name_or_path
):
split_model_dir
=
model_name_or_path
elif
not
os
.
path
.
exists
(
f
"
{
model_name_or_path
}
-split"
):
hf_model_cls
=
getattr
(
transformers
,
hf_model_cls_name
)
...
...
@@ -97,6 +97,23 @@ class NeuronCasualLM(nn.Module):
self
.
model
.
to_neuron
()
def _is_pretrained_neuron_checkpoint(model_name_or_path: str) -> bool:
    """Tell whether ``model_name_or_path`` already holds a neuron checkpoint.

    Two layouts are recognised:

    * old format - ``pytorch_model.bin`` exists as a directory inside the path;
    * new format - both ``config.json`` and ``generation_config.json`` are
      regular files in the path and at least one entry ends in
      ``.safetensors``.
    """
    # Old format: ``pytorch_model.bin`` is itself a directory.
    legacy_dir = os.path.join(model_name_or_path, "pytorch_model.bin")
    if os.path.isdir(legacy_dir):
        return True

    # New format: the required JSON files must all be present ...
    required_names = ("config.json", "generation_config.json")
    has_required = all(
        os.path.isfile(os.path.join(model_name_or_path, name))
        for name in required_names)
    if not has_required:
        return False

    # ... alongside at least one safetensors file.
    return any(
        entry.endswith(".safetensors")
        for entry in os.listdir(model_name_or_path))
def
_get_model_architecture
(
config
:
PretrainedConfig
)
->
str
:
architectures
=
getattr
(
config
,
"architectures"
,
[])
for
arch
in
architectures
:
...
...
@@ -119,19 +136,51 @@ def _get_buckets(env: str, default_value: List[int]) -> List[int]:
return
buckets_list
def _get_default_neuron_config(model_config: ModelConfig,
                               parallel_config: ParallelConfig,
                               scheduler_config: SchedulerConfig):
    """Assemble the default keyword arguments for a neuron config.

    Returns a dict with the keys ``collectives_layout``, ``attention_layout``,
    ``fuse_qkv``, ``quant``, ``continuous_batching`` and ``weight_tiling``.
    ``quant`` is ``None`` and ``weight_tiling`` is ``False`` unless
    ``model_config.quantization`` is set. ``parallel_config`` is accepted but
    not read here.
    """
    from transformers_neuronx.config import ContinuousBatchingConfig
    from transformers_neuronx.constants import LAYOUT_BSH

    batching = ContinuousBatchingConfig(
        batch_size_for_shared_caches=scheduler_config.max_num_seqs)

    # Built unconditionally, so the dtype lookup happens whether or not
    # quantization is enabled.
    quant_settings = {
        "dequant_dtype": TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
        "quantize_method": "vector_dynamic",
    }

    quantization = model_config.quantization
    if quantization:
        quant_cls = get_quantization_config(quantization)
        quant = quant_cls.from_config(quant_settings).get_quant_method(
            None, "")
    else:
        quant = None

    # TODO: Add Paged attention config to the default neuron arguments.
    return {
        "collectives_layout": LAYOUT_BSH,
        "attention_layout": LAYOUT_BSH,
        "fuse_qkv": True,
        "quant": quant,
        "continuous_batching": batching,
        "weight_tiling": bool(quantization),
    }
def _get_neuron_config_after_override(default_neuron_config,
                                      overridden_neuron_config):
    """Apply user overrides to the defaults and build a ``NeuronConfig``.

    ``default_neuron_config`` is updated in place with the entries of
    ``overridden_neuron_config`` (a falsy value means "no overrides"), and the
    merged mapping is expanded into ``NeuronConfig`` keyword arguments.
    """
    from transformers_neuronx.config import NeuronConfig

    # Nothing to merge when no overrides were supplied.
    if overridden_neuron_config:
        default_neuron_config.update(overridden_neuron_config)
    return NeuronConfig(**default_neuron_config)
def
get_neuron_model
(
model_config
:
ModelConfig
,
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
)
->
nn
.
Module
:
from
transformers_neuronx.config
import
(
ContinuousBatchingConfig
,
NeuronConfig
)
# Create a model instance.
model
=
NeuronCasualLM
(
model_config
.
hf_config
)
continuous_batching_config
=
ContinuousBatchingConfig
(
batch_size_for_shared_caches
=
scheduler_config
.
max_num_seqs
)
neuron_config
=
NeuronConfig
(
continuous_batching
=
continuous_batching_config
)
default_neuron_config_args
=
_get_default_neuron_config
(
model_config
,
parallel_config
,
scheduler_config
)
neuron_config
=
_get_neuron_config_after_override
(
default_neuron_config_args
,
model_config
.
override_neuron_config
)
context_length_estimates
=
_get_buckets
(
"NEURON_CONTEXT_LENGTH_BUCKETS"
,
[
scheduler_config
.
max_model_len
])
...
...
vllm/worker/neuron_model_runner.py
View file @
008cf886
from
dataclasses
import
dataclass
from
importlib.util
import
find_spec
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
import
torch
...
...
@@ -76,9 +77,14 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
self
.
model
:
nn
.
Module
# initialize after load_model.
def
load_model
(
self
)
->
None
:
self
.
model
=
get_neuron_model
(
self
.
model_config
,
parallel_config
=
self
.
parallel_config
,
scheduler_config
=
self
.
scheduler_config
)
if
find_spec
(
"transformers_neuronx"
)
is
not
None
:
self
.
model
=
get_neuron_model
(
self
.
model_config
,
parallel_config
=
self
.
parallel_config
,
scheduler_config
=
self
.
scheduler_config
)
else
:
raise
NotImplementedError
(
"Supports only Transformer-NeuronX based models."
)
def
_prepare_prompt
(
self
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment