Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cc90419e
Unverified
Commit
cc90419e
authored
Oct 04, 2024
by
Chongming Ni
Committed by
GitHub
Oct 04, 2024
Browse files
[Hardware][Neuron] Add on-device sampling support for Neuron (#8746)
Co-authored-by:
Ashraf Mahgoub
<
ashymahg@amazon.com
>
parent
27302dd5
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
128 additions
and
13 deletions
+128
-13
vllm/model_executor/model_loader/neuron.py
vllm/model_executor/model_loader/neuron.py
+50
-9
vllm/worker/neuron_model_runner.py
vllm/worker/neuron_model_runner.py
+78
-4
No files found.
vllm/model_executor/model_loader/neuron.py
View file @
cc90419e
"""Utilities for selecting and loading neuron models."""
"""Utilities for selecting and loading neuron models."""
import
copy
import
importlib
import
importlib
import
os
import
os
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Tuple
...
@@ -13,6 +14,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
...
@@ -13,6 +14,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from
vllm.model_executor.layers.quantization
import
get_quantization_config
from
vllm.model_executor.layers.quantization
import
get_quantization_config
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
Logprob
,
SequenceOutput
)
TORCH_DTYPE_TO_NEURON_AMP
=
{
TORCH_DTYPE_TO_NEURON_AMP
=
{
"auto"
:
"f32"
,
"auto"
:
"f32"
,
...
@@ -37,14 +40,17 @@ _NEURON_SUPPORTED_MODELS: Dict[str, Tuple[str, str, str]] = {
...
@@ -37,14 +40,17 @@ _NEURON_SUPPORTED_MODELS: Dict[str, Tuple[str, str, str]] = {
class
NeuronCasualLM
(
nn
.
Module
):
class
NeuronCasualLM
(
nn
.
Module
):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
)
->
None
:
on_device_sampling_disabled
:
bool
=
False
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
,
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
,
logits_as_input
=
True
)
logits_as_input
=
True
)
self
.
on_device_sampling_disabled
=
on_device_sampling_disabled
if
self
.
on_device_sampling_disabled
:
# Use default sampler
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
# Lazy initialized
# Lazy initialized
...
@@ -71,9 +77,30 @@ class NeuronCasualLM(nn.Module):
...
@@ -71,9 +77,30 @@ class NeuronCasualLM(nn.Module):
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
if
self
.
on_device_sampling_disabled
:
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
return
next_tokens
# On-device sampling outputs the token ids directly.
sampled_token_ids
=
logits
.
flatten
()
next_tokens
=
[]
sample_idx
=
0
for
seq_group
in
sampling_metadata
.
seq_groups
:
samples
=
[]
for
seq_id
in
seq_group
.
seq_ids
:
token_id
=
sampled_token_ids
[
sample_idx
].
item
()
samples
.
append
(
SequenceOutput
(
parent_seq_id
=
seq_id
,
output_token
=
token_id
,
logprobs
=
{
token_id
:
Logprob
(
token_id
)}))
sample_idx
+=
1
next_tokens
.
append
(
CompletionSequenceGroupOutput
(
samples
=
samples
,
prompt_logprobs
=
None
))
return
SamplerOutput
(
outputs
=
next_tokens
)
def
load_weights
(
self
,
model_name_or_path
:
str
,
**
kwargs
):
def
load_weights
(
self
,
model_name_or_path
:
str
,
**
kwargs
):
arch
=
_get_model_architecture
(
self
.
config
)
arch
=
_get_model_architecture
(
self
.
config
)
neuronx_module_path
,
neuronx_model_cls_name
,
hf_model_cls_name
=
(
neuronx_module_path
,
neuronx_model_cls_name
,
hf_model_cls_name
=
(
...
@@ -157,10 +184,22 @@ def _get_default_neuron_config(model_config: ModelConfig,
...
@@ -157,10 +184,22 @@ def _get_default_neuron_config(model_config: ModelConfig,
quant
=
neuron_quantization_config_builder
(
model_config
.
quantization
)
quant
=
neuron_quantization_config_builder
(
model_config
.
quantization
)
if
model_config
.
quantization
else
None
,
if
model_config
.
quantization
else
None
,
continuous_batching
=
continuous_batching_config
,
continuous_batching
=
continuous_batching_config
,
weight_tiling
=
bool
(
model_config
.
quantization
))
weight_tiling
=
bool
(
model_config
.
quantization
),
on_device_generation
=
_get_neuron_on_device_generation_config
(
model_config
))
return
default_neuron_args
return
default_neuron_args
def
_get_neuron_on_device_generation_config
(
model_config
:
ModelConfig
):
if
not
_is_neuron_on_device_sampling_disabled
(
model_config
):
return
copy
.
deepcopy
(
model_config
.
neuron_sampling_params
)
return
None
def
_is_neuron_on_device_sampling_disabled
(
model_config
:
ModelConfig
)
->
bool
:
return
not
getattr
(
model_config
,
"neuron_sampling_params"
,
None
)
def
_get_neuron_config_after_override
(
default_neuron_config
,
def
_get_neuron_config_after_override
(
default_neuron_config
,
overridden_neuron_config
):
overridden_neuron_config
):
from
transformers_neuronx.config
import
NeuronConfig
from
transformers_neuronx.config
import
NeuronConfig
...
@@ -174,7 +213,9 @@ def get_neuron_model(model_config: ModelConfig,
...
@@ -174,7 +213,9 @@ def get_neuron_model(model_config: ModelConfig,
scheduler_config
:
SchedulerConfig
)
->
nn
.
Module
:
scheduler_config
:
SchedulerConfig
)
->
nn
.
Module
:
# Create a model instance.
# Create a model instance.
model
=
NeuronCasualLM
(
model_config
.
hf_config
)
model
=
NeuronCasualLM
(
model_config
.
hf_config
,
_is_neuron_on_device_sampling_disabled
(
model_config
))
default_neuron_config_args
=
_get_default_neuron_config
(
default_neuron_config_args
=
_get_default_neuron_config
(
model_config
,
parallel_config
,
scheduler_config
)
model_config
,
parallel_config
,
scheduler_config
)
...
...
vllm/worker/neuron_model_runner.py
View file @
cc90419e
import
os
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
importlib.util
import
find_spec
from
importlib.util
import
find_spec
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers_neuronx.config
import
GenerationConfig
from
vllm.config
import
(
DeviceConfig
,
ModelConfig
,
ParallelConfig
,
from
vllm.config
import
(
DeviceConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
)
SchedulerConfig
)
...
@@ -50,6 +52,9 @@ class ModelInputForNeuron(ModelRunnerInputBase):
...
@@ -50,6 +52,9 @@ class ModelInputForNeuron(ModelRunnerInputBase):
class
NeuronModelRunner
(
ModelRunnerBase
[
ModelInputForNeuron
]):
class
NeuronModelRunner
(
ModelRunnerBase
[
ModelInputForNeuron
]):
# NEURON has an upper limit on the top_k
_MAX_NEURON_SAMPLING_TOP_K
=
256
def
__init__
(
def
__init__
(
self
,
self
,
model_config
:
ModelConfig
,
model_config
:
ModelConfig
,
...
@@ -76,6 +81,34 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
...
@@ -76,6 +81,34 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
# Lazy initialization.
# Lazy initialization.
self
.
model
:
nn
.
Module
# initialize after load_model.
self
.
model
:
nn
.
Module
# initialize after load_model.
# Once NEURON_ON_DEVICE_SAMPLING_DISABLED is set to a non-zero value,
# turn off on-device sampling.
self
.
_on_device_sampling_disabled
=
int
(
os
.
getenv
(
"NEURON_ON_DEVICE_SAMPLING_DISABLED"
,
"0"
))
# NEURON needs to update sampling parameters when request IDs change
# across batches. This variable stores the previous batch's request IDs
# to determine if an update is needed.
self
.
_previous_batch_request_ids
:
List
[
str
]
=
[]
if
not
self
.
_on_device_sampling_disabled
:
logger
.
warning
(
"On-device sampling is turned on in Neuron by default, only "
"top_k, top_p, and temperature are current supported sampling "
"parameters. To turn off the on-device sampling, please set "
"the environment variable NEURON_ON_DEVICE_SAMPLING_DISABLED=1."
)
self
.
model_config
.
neuron_sampling_params
=
GenerationConfig
(
max_length
=
self
.
scheduler_config
.
max_model_len
,
do_sample
=
True
,
per_batch_line
=
True
,
top_k
=
[
self
.
_MAX_NEURON_SAMPLING_TOP_K
]
\
*
self
.
scheduler_config
.
max_num_seqs
,
top_p
=
[
1.0
]
*
self
.
scheduler_config
.
max_num_seqs
,
temperature
=
[
1.0
]
*
self
.
scheduler_config
.
max_num_seqs
,
dynamic
=
True
,
global_top_k
=
self
.
_MAX_NEURON_SAMPLING_TOP_K
)
def
load_model
(
self
)
->
None
:
def
load_model
(
self
)
->
None
:
if
find_spec
(
"transformers_neuronx"
)
is
not
None
:
if
find_spec
(
"transformers_neuronx"
)
is
not
None
:
self
.
model
=
get_neuron_model
(
self
.
model
=
get_neuron_model
(
...
@@ -215,7 +248,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
...
@@ -215,7 +248,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
else
:
else
:
(
input_tokens
,
input_positions
,
(
input_tokens
,
input_positions
,
input_block_ids
)
=
self
.
_prepare_decode
(
seq_group_metadata_list
)
input_block_ids
)
=
self
.
_prepare_decode
(
seq_group_metadata_list
)
seq_lens
=
[]
seq_lens
=
None
sampling_metadata
=
SamplingMetadata
.
prepare
(
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
seq_group_metadata_list
,
seq_lens
,
seq_lens
,
...
@@ -227,12 +260,49 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
...
@@ -227,12 +260,49 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
self
.
pin_memory
,
self
.
pin_memory
,
generators
=
self
.
get_generators
(
finished_requests_ids
))
generators
=
self
.
get_generators
(
finished_requests_ids
))
if
not
self
.
_on_device_sampling_disabled
:
# Once the request IDs are changed in current iteration, we will
# update the on-device sampling parameters.
current_batch_request_ids
=
[
seq_group_meta_data
.
request_id
for
seq_group_meta_data
in
seq_group_metadata_list
]
if
current_batch_request_ids
!=
self
.
_previous_batch_request_ids
:
self
.
_update_neuron_sampling_params
(
sampling_metadata
)
self
.
_previous_batch_request_ids
=
current_batch_request_ids
return
ModelInputForNeuron
(
input_tokens
=
input_tokens
,
return
ModelInputForNeuron
(
input_tokens
=
input_tokens
,
input_positions
=
input_positions
,
input_positions
=
input_positions
,
input_block_ids
=
input_block_ids
,
input_block_ids
=
input_block_ids
,
sampling_metadata
=
sampling_metadata
,
sampling_metadata
=
sampling_metadata
,
multi_modal_kwargs
=
multi_modal_kwargs
)
multi_modal_kwargs
=
multi_modal_kwargs
)
def
_update_neuron_sampling_params
(
self
,
sampling_metadata
:
SamplingMetadata
):
# Update Neuron sampling parameters (GenerationConfig in Neuron)
current_sampling_params
=
self
.
model_config
.
neuron_sampling_params
assert
current_sampling_params
is
not
None
,
(
f
"Failed to update sampling_params, "
f
"current sampling params is
{
current_sampling_params
}
"
)
top_k
=
current_sampling_params
.
top_k
top_p
=
current_sampling_params
.
top_p
temperature
=
current_sampling_params
.
temperature
for
index
,
sequence_group_to_sample
in
enumerate
(
sampling_metadata
.
seq_groups
):
top_k
[
index
]
=
self
.
_convert_to_neuron_top_k
(
sequence_group_to_sample
.
sampling_params
.
top_k
)
top_p
[
index
]
=
sequence_group_to_sample
.
sampling_params
.
top_p
temperature
[
index
]
=
\
sequence_group_to_sample
.
sampling_params
.
temperature
self
.
model
.
model
.
update_generation_config
(
current_sampling_params
)
def
_convert_to_neuron_top_k
(
self
,
top_k
:
int
)
->
int
:
if
top_k
<
0
or
top_k
>
self
.
_MAX_NEURON_SAMPLING_TOP_K
:
return
self
.
_MAX_NEURON_SAMPLING_TOP_K
return
top_k
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
execute_model
(
def
execute_model
(
self
,
self
,
...
@@ -253,9 +323,13 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
...
@@ -253,9 +323,13 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
device
=
self
.
device
),
device
=
self
.
device
),
)
)
# Compute the logits.
# Compute the logits only if the on-device sampling is turned off as
# on-device sampling outputs the token ids.
if
self
.
_on_device_sampling_disabled
:
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
model_input
.
sampling_metadata
)
model_input
.
sampling_metadata
)
else
:
logits
=
hidden_states
# Sample the next token.
# Sample the next token.
output
=
self
.
model
.
sample
(
output
=
self
.
model
.
sample
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment