Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4172235a
Unverified
Commit
4172235a
authored
Sep 06, 2025
by
Woosuk Kwon
Committed by
GitHub
Sep 06, 2025
Browse files
[V0 deprecation] Deprecate V0 Neuron backend (#21159)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
848562bd
Changes
46
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
8 additions
and
2720 deletions
+8
-2720
tests/neuron/1_core/test_rotary_embedding.py
tests/neuron/1_core/test_rotary_embedding.py
+0
-68
tests/neuron/2_core/test_comm_ops.py
tests/neuron/2_core/test_comm_ops.py
+0
-101
tests/neuron/2_core/test_eagle.py
tests/neuron/2_core/test_eagle.py
+0
-83
tests/neuron/2_core/test_mistral.py
tests/neuron/2_core/test_mistral.py
+0
-64
tests/neuron/2_core/test_multi_lora.py
tests/neuron/2_core/test_multi_lora.py
+0
-97
vllm/attention/ops/nki_flash_attn.py
vllm/attention/ops/nki_flash_attn.py
+0
-903
vllm/collect_env.py
vllm/collect_env.py
+1
-15
vllm/config/__init__.py
vllm/config/__init__.py
+2
-20
vllm/config/cache.py
vllm/config/cache.py
+2
-3
vllm/config/parallel.py
vllm/config/parallel.py
+1
-4
vllm/distributed/device_communicators/neuron_communicator.py
vllm/distributed/device_communicators/neuron_communicator.py
+0
-20
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+0
-5
vllm/envs.py
vllm/envs.py
+1
-1
vllm/model_executor/custom_op.py
vllm/model_executor/custom_op.py
+0
-7
vllm/model_executor/layers/activation.py
vllm/model_executor/layers/activation.py
+0
-7
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+0
-3
vllm/model_executor/layers/quantization/neuron_quant.py
vllm/model_executor/layers/quantization/neuron_quant.py
+0
-76
vllm/model_executor/layers/rotary_embedding/base.py
vllm/model_executor/layers/rotary_embedding/base.py
+1
-82
vllm/model_executor/model_loader/neuron.py
vllm/model_executor/model_loader/neuron.py
+0
-476
vllm/model_executor/model_loader/neuronx_distributed.py
vllm/model_executor/model_loader/neuronx_distributed.py
+0
-685
No files found.
tests/neuron/1_core/test_rotary_embedding.py
deleted
100644 → 0
View file @
848562bd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for miscellaneous utilities
"""
import
pytest
import
torch
from
vllm.model_executor.layers.rotary_embedding
import
RotaryEmbedding
from
vllm.platforms
import
current_platform
@pytest.mark.parametrize(
    "max_position,is_neox_style,rotary_dim,head_size,seq_len,use_key", [
        (16, False, 32, 32, 1024, True),
        (16, False, 32, 128, 1024, True),
        (16, True, 32, 32, 1024, True),
        (16, True, 32, 128, 1024, True),
        (16, False, 32, 128, 1024, False),
        (16, True, 32, 128, 1024, False),
    ])
def test_rotary_embedding_opcheck(max_position, is_neox_style, rotary_dim,
                                  head_size, seq_len, use_key):
    """Compare ``forward_neuron`` on an XLA device with ``forward_native``.

    Inputs are generated on CPU, the native path is evaluated on CPU as the
    reference, and the Neuron path is evaluated on the XLA device; results
    must agree within ``atol=rtol=1e-2``. When ``use_key`` is false the key
    input is ``None`` and the returned key must be ``None`` as well.
    """
    import torch_xla.core.xla_model as xm

    xla_dev = xm.xla_device()
    current_platform.seed_everything(0)
    torch.set_default_device("cpu")

    batch_size, base, num_heads = 1, 10000, 8

    rope = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                           is_neox_style, torch.float32)

    # Draw the random inputs in a fixed order so the seeded RNG stream is
    # consumed identically on every run.
    positions = torch.randint(0,
                              max_position, (batch_size, seq_len),
                              device="cpu")
    query = torch.randn(batch_size,
                        seq_len,
                        num_heads * head_size,
                        dtype=torch.float32,
                        device="cpu")
    if use_key:
        key = torch.randn_like(query)
    else:
        key = None
    assert positions.is_cpu, \
        "reference input tensor is expected to be CPU tensor."

    # CPU reference.
    cpu_rope = rope.to(device="cpu")
    ref_query, ref_key = cpu_rope.forward_native(positions, query, key)

    # Device under test.
    if key is not None:
        xla_key = key.to(device=xla_dev)
    else:
        xla_key = None
    out_query, out_key = rope.to(device=xla_dev).forward_neuron(
        positions.to(device=xla_dev), query.to(device=xla_dev), xla_key)

    tolerances = {"atol": 1e-2, "rtol": 1e-2}
    if not use_key:
        assert out_key is None, "expected returned key to be None"
        assert out_query.is_xla, \
            "output tensor is expected to be XLA tensor"
    else:
        assert out_query.is_xla and out_key.is_xla, \
            "output tensor is expected to be XLA tensor"
        torch.testing.assert_close(out_key.cpu(), ref_key, **tolerances)
    torch.testing.assert_close(out_query.cpu(), ref_query, **tolerances)
tests/neuron/2_core/test_comm_ops.py
deleted
100644 → 0
View file @
848562bd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
functools
from
typing
import
Callable
from
unittest.mock
import
patch
import
pytest
import
torch
import
torch_xla.distributed.xla_multiprocessing
as
xmp
from
typing_extensions
import
ParamSpec
from
vllm.distributed.communication_op
import
(
tensor_model_parallel_all_gather
,
tensor_model_parallel_all_reduce
)
from
vllm.distributed.parallel_state
import
(
ensure_model_parallel_initialized
,
init_distributed_environment
)
from
vllm.utils
import
get_distributed_init_method
,
get_open_port
_P
=
ParamSpec
(
"_P"
)
def reinitialize_neuron_runtime(f: Callable[_P, None]) -> Callable[_P, None]:
    """Decorator to reinitialize the Neuron Runtime before executing a test.
    This is necessary for distributed tests which need to reallocate Neuron
    Cores to separate subprocesses.

    The wrapper initializes a ``torch.classes.neuron.Runtime`` instance,
    closes it with ``unsafe_close()``, runs the wrapped callable, and then
    initializes the runtime again. The final re-initialization happens in a
    ``finally`` clause so that a failing test does not leave the runtime
    closed for whatever runs next in the same process.

    Args:
        f: The test callable to wrap; its return value is discarded.

    Returns:
        A wrapper with the same call signature as ``f`` that returns ``None``.
    """

    @functools.wraps(f)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
        runtime = torch.classes.neuron.Runtime()
        runtime.initialize()
        runtime.unsafe_close()

        try:
            f(*args, **kwargs)
        finally:
            # Restore the runtime even when the wrapped test raised.
            runtime.initialize()

    return wrapper
def all_gather_test_worker(index, tp_degree, distributed_init_method):
    """Worker body that checks tensor-parallel all-gather on the XLA backend.

    Every rank builds the full list of per-rank tensors (rank ``r`` holds an
    ``arange`` tensor of shape ``[2, 3, 4]`` scaled by ``r + 1``), gathers its
    own tensor along the last dimension, and compares the result with the
    concatenation of all of them.
    """
    init_distributed_environment(tp_degree,
                                 index,
                                 distributed_init_method,
                                 index,
                                 backend="xla")
    ensure_model_parallel_initialized(tp_degree, 1)

    num_dimensions = 3
    shape = list(range(2, num_dimensions + 2))
    numel = functools.reduce(lambda acc, extent: acc * extent, shape, 1)

    gather_dim = -1
    per_rank = [
        torch.arange(numel, dtype=torch.float32,
                     device="xla").reshape(shape) * (rank + 1)
        for rank in range(tp_degree)
    ]
    expected = torch.cat(per_rank, dim=gather_dim)

    local = per_rank[index % tp_degree]
    gathered = tensor_model_parallel_all_gather(local, gather_dim)
    torch.testing.assert_close(gathered, expected)
def all_reduce_test_worker(index, tp_degree, distributed_init_method):
    """Worker body that checks tensor-parallel all-reduce on the XLA backend.

    Rank ``r`` contributes an 8-element ``arange`` tensor scaled by ``r + 1``;
    the all-reduced result must equal the element-wise sum over all ranks.
    """
    init_distributed_environment(tp_degree,
                                 index,
                                 distributed_init_method,
                                 index,
                                 backend="xla")
    ensure_model_parallel_initialized(tp_degree, 1)

    num_elements = 8
    per_rank = [
        torch.arange(num_elements, dtype=torch.float32, device="xla") *
        (rank + 1) for rank in range(tp_degree)
    ]
    expected = torch.stack(per_rank, dim=0).sum(dim=0)

    local = per_rank[index % tp_degree]
    reduced = tensor_model_parallel_all_reduce(local)
    torch.testing.assert_close(reduced, expected)
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("test_target",
                         [all_reduce_test_worker, all_gather_test_worker])
@reinitialize_neuron_runtime
def test_neuron_multi_process_tensor_parallel(monkeypatch, tp_size,
                                              test_target):
    """Spawn ``test_target`` across XLA processes with ``tp_size`` as its
    tensor-parallel degree.

    ``torch_xla._XLAC._xla_runtime_is_initialized`` is patched to report
    ``False`` for the duration of the spawn, and the Neuron environment
    variables are set so that each process is assigned one device.
    """
    with patch('torch_xla._XLAC._xla_runtime_is_initialized',
               return_value=False):
        port = get_open_port()
        distributed_init_method = get_distributed_init_method(
            "127.0.0.1", port)

        # One device per spawned process, e.g. "1,1" for tp_size == 2.
        devices_per_process = ','.join(['1'] * tp_size)

        monkeypatch.setenv("VLLM_USE_V1", "1")
        monkeypatch.setenv("NEURONCORE_NUM_DEVICES", str(tp_size))
        monkeypatch.setenv("NEURON_PJRT_PROCESSES_NUM_DEVICES",
                           devices_per_process)

        xmp.spawn(test_target, args=(tp_size, distributed_init_method))
tests/neuron/2_core/test_eagle.py
deleted
100644 → 0
View file @
848562bd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
json
import
os
import
shutil
import
tempfile
import
torch
from
huggingface_hub
import
snapshot_download
from
safetensors
import
safe_open
from
vllm
import
LLM
,
SamplingParams
def patch_eagle_draft_with_lm_head(target_model_id: str,
                                   draft_model_id: str) -> str:
    """Copy the target model's ``lm_head.weight`` into an EAGLE draft
    checkpoint and return the directory holding the patched draft.

    In NxDI, the draft model checkpoint must include lm_head weights from the
    target model. For more details see
    https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/nxd-inference/developer_guides/feature-guide.html#eagle-checkpoint-compatibility

    Both repositories are downloaded into a temporary directory; the patched
    draft is then copied to ``/tmp/patched_eagle_draft`` (existing contents
    are allowed) before the temporary directory is removed.

    Args:
        target_model_id: Hub repo id of the target model (sharded safetensors
            with a ``model.safetensors.index.json``).
        draft_model_id: Hub repo id of the draft model (``pytorch_model.bin``).

    Returns:
        The path of the patched draft directory.
    """
    lm_head_key = "lm_head.weight"
    final_draft_dir = "/tmp/patched_eagle_draft"

    with tempfile.TemporaryDirectory() as scratch:
        target_dir = snapshot_download(repo_id=target_model_id,
                                       local_dir=os.path.join(
                                           scratch, "target"))
        draft_dir = snapshot_download(repo_id=draft_model_id,
                                      local_dir=os.path.join(scratch, "draft"))

        # The index maps each weight name to the shard file that stores it.
        with open(os.path.join(target_dir,
                               "model.safetensors.index.json")) as index_file:
            shard_name = json.load(index_file)["weight_map"][lm_head_key]

        with safe_open(os.path.join(target_dir, shard_name),
                       framework="pt") as shard:
            target_lm_head = shard.get_tensor(lm_head_key)

        # Rewrite the draft checkpoint in place with the extra weight.
        # NOTE(review): torch.load unpickles a downloaded file; only use with
        # trusted repositories.
        draft_path = os.path.join(draft_dir, "pytorch_model.bin")
        state = torch.load(draft_path, map_location="cpu")
        state[lm_head_key] = target_lm_head.to(torch.float16)
        torch.save(state, draft_path)

        shutil.copytree(draft_dir, final_draft_dir, dirs_exist_ok=True)

    return final_draft_dir
def test_eagle():
    """Run greedy generation with fused EAGLE speculation on Neuron and
    compare the completion with a fixed expected string."""
    target_model = "meta-llama/Llama-2-7b-hf"
    patched_draft_path = patch_eagle_draft_with_lm_head(
        target_model_id=target_model,
        draft_model_id="yuhuili/EAGLE-llama2-chat-7B")

    speculative_config = {
        "model": patched_draft_path,
        "num_speculative_tokens": 5,
        "max_model_len": 128
    }
    neuron_overrides = {
        "enable_eagle_speculation": True,
        "enable_fused_speculation": True,
        "fused_qkv": True
    }
    llm = LLM(
        model=target_model,
        speculative_config=speculative_config,
        max_num_seqs=1,
        max_model_len=128,
        tensor_parallel_size=2,
        override_neuron_config=neuron_overrides,
    )

    prompts = ["The president of the United States is"]
    # top_k=1 makes sampling greedy so the output can be compared exactly.
    outputs = llm.generate(prompts, SamplingParams(top_k=1))

    expected_output = (" the head of state and head of government of "
                       "the United States. The president direct")

    for result in outputs:
        completion = result.outputs[0].text
        print(f"Prompt: {result.prompt!r}, Generated text: {completion!r}")
        assert (expected_output == completion)

    print("Neuron Eagle speculation test passed.")
tests/neuron/2_core/test_mistral.py
deleted
100644 → 0
View file @
848562bd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
vllm
import
LLM
,
SamplingParams
def test_mistral():
    """Generate greedily with Mistral-7B on Neuron and compare every
    completion with a fixed expected string.

    More prompts are sent than the compiled batch size (4), each with a
    different generation length, to test accuracy related to Neuron specific
    sequence id sorting.
    """
    neuron_overrides = {
        "sequence_parallel_enabled": False,
        "skip_warmup": True
    }
    llm = LLM(model="mistralai/Mistral-7B-v0.1",
              tensor_parallel_size=2,
              max_num_seqs=4,
              max_model_len=128,
              override_neuron_config=neuron_overrides)

    prompts = [
        "The president of the United States is",
        "The capital of France is",
        "What is Annapurna labs?",
        "I believe the meaning of life is",
        "Tell me a story about a brave knight",
        "Hello, my name is Llama",
    ]
    # One greedy request per prompt with max_tokens = 10, 20, ..., 60.
    sampling_params = [
        SamplingParams(top_k=1, max_tokens=limit)
        for limit in range(10, 70, 10)
    ]

    outputs = llm.generate(prompts, sampling_params)

    expected_outputs = [
        " the most powerful person in the world. He is",
        " a city of many faces. It is a city of history, culture, art, "
        "fashion, and",
        "\n\nAnnapurna Labs is a semiconductor company that was founded "
        "in 2013 by Amazon. The company is",
        " to be happy.\n\nI believe that happiness is a choice.\n\nI "
        "believe that happiness is a state of mind.\n\nI believe that "
        "happiness is a journey.\n\nI believe",
        " who rescued a princess from a dragon.\n\nTell me a story about"
        " a princess who rescued herself from a dragon.\n\nTell me a "
        "story about a princess who rescued herself from a dragon and "
        "then rescued a knight from",
        " and I am a 10 year old male. I am a very friendly and "
        "affectionate boy who loves to be around people. I am a very "
        "active boy who loves to play and run around. I am a very smart "
        "boy who loves to learn new things. I am a very loyal boy"
    ]

    for wanted, result in zip(expected_outputs, outputs):
        completion = result.outputs[0].text
        print(f"Prompt: {result.prompt!r}, Generated text: {completion!r}")
        assert (wanted == completion)

    print("Neuron Mistral test passed.")
tests/neuron/2_core/test_multi_lora.py
deleted
100644 → 0
View file @
848562bd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
huggingface_hub
import
snapshot_download
from
vllm
import
LLM
,
SamplingParams
from
vllm.lora.request
import
LoRARequest
def test_llama_single_lora():
    """Serve one LoRA adapter on Neuron and check greedy completions for two
    prompts that both use that adapter."""
    sql_lora_files = snapshot_download(
        repo_id="yard1/llama-2-7b-sql-lora-test")

    neuron_overrides = {
        "sequence_parallel_enabled": False,
        "skip_warmup": True,
        "lora_modules": [{
            "name": "lora_id_1",
            "path": sql_lora_files
        }]
    }
    llm = LLM(model="meta-llama/Llama-2-7b-hf",
              tensor_parallel_size=2,
              max_num_seqs=4,
              max_model_len=512,
              override_neuron_config=neuron_overrides,
              enable_lora=True,
              max_loras=1,
              max_lora_rank=256,
              device="neuron")

    # For multi-lora requests using NxDI as the backend, only the lora_name
    # needs to be specified. The lora_id and lora_path are supplied at the LLM
    # class/server initialization, after which the paths are handled by NxDI
    lora_req_1 = LoRARequest("lora_id_1", 0, " ")

    prompts = [
        "The president of the United States is",
        "The capital of France is",
    ]
    outputs = llm.generate(prompts,
                           SamplingParams(top_k=1),
                           lora_request=[lora_req_1] * 2)

    expected_outputs = [
        " the head of state and head of government of the United States. "
        "The president direct",
        " a city of contrasts. The city is home to the Eiffel Tower"
    ]

    for wanted, result in zip(expected_outputs, outputs):
        assert (wanted == result.outputs[0].text)
def test_llama_multiple_lora():
    """Serve two LoRA adapters (same weights, different names) on Neuron and
    check greedy completions when each prompt uses a different adapter."""
    sql_lora_files = snapshot_download(
        repo_id="yard1/llama-2-7b-sql-lora-test")

    lora_modules = [{
        "name": "lora_id_1",
        "path": sql_lora_files
    }, {
        "name": "lora_id_2",
        "path": sql_lora_files
    }]
    neuron_overrides = {
        "sequence_parallel_enabled": False,
        "skip_warmup": True,
        "lora_modules": lora_modules
    }
    llm = LLM(model="meta-llama/Llama-2-7b-hf",
              tensor_parallel_size=2,
              max_num_seqs=4,
              max_model_len=512,
              override_neuron_config=neuron_overrides,
              enable_lora=True,
              max_loras=2,
              max_lora_rank=256,
              device="neuron")

    # For multi-lora requests using NxDI as the backend, only the lora_name
    # needs to be specified. The lora_id and lora_path are supplied at the LLM
    # class/server initialization, after which the paths are handled by NxDI
    lora_req_1 = LoRARequest("lora_id_1", 0, " ")
    lora_req_2 = LoRARequest("lora_id_2", 1, " ")

    prompts = [
        "The president of the United States is",
        "The capital of France is",
    ]
    outputs = llm.generate(prompts,
                           SamplingParams(top_k=1),
                           lora_request=[lora_req_1, lora_req_2])

    expected_outputs = [
        " the head of state and head of government of the United States. "
        "The president direct",
        " a city of contrasts. The city is home to the Eiffel Tower"
    ]

    for wanted, result in zip(expected_outputs, outputs):
        assert (wanted == result.outputs[0].text)
vllm/attention/ops/nki_flash_attn.py
deleted
100644 → 0
View file @
848562bd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
neuronxcc.nki.isa
as
nisa
import
neuronxcc.nki.language
as
nl
import
numpy
as
np
import
torch
from
neuronxcc
import
nki
from
neuronxcc.nki.language
import
par_dim
from
vllm.utils
import
cdiv
def is_power_of_2(x):
    """Return True when ``x`` is positive and has exactly one bit set."""
    if not x > 0:
        return False
    # Clearing the lowest set bit leaves zero only for powers of two.
    return (x & (x - 1)) == 0
@nki.jit
def load_block_tables(block_tables_hbm, num_tiles, num_blocks_per_tile):
    """
    Load block tables from HBM into SRAM

    `block_tables_hbm` has shape `(num_tiles * num_blocks_per_tile, )`.
    In case `num_tiles > B_P_SIZE`, we need further tile `num_tile` dimension.

    Returns an int32 buffer of shape
    `(cdiv(num_tiles, B_P_SIZE), par_dim(B_P_SIZE), num_blocks_per_tile)`.
    The buffer is zero-initialized and the load is masked to rows below
    `num_tiles`.
    """
    B_P_SIZE = 128

    # reshape as `(num_tiles, num_blocks_per_tile)`
    assert len(block_tables_hbm.shape) == 1
    (num_total_blocks, ) = block_tables_hbm.shape
    assert num_blocks_per_tile * num_tiles == num_total_blocks
    block_tables_hbm = block_tables_hbm.reshape(
        (num_tiles, num_blocks_per_tile))

    block_tables_sbuf = nl.zeros(
        (cdiv(num_tiles, B_P_SIZE), par_dim(B_P_SIZE), num_blocks_per_tile),
        dtype=nl.int32,
    )
    # Each iteration loads up to B_P_SIZE consecutive tiles (rows).
    for i in nl.affine_range(cdiv(num_tiles, B_P_SIZE)):
        i_p = nl.arange(B_P_SIZE)[:, None]
        i_f = nl.arange(num_blocks_per_tile)[None, :]
        block_tables_sbuf[i, i_p, i_f] = nl.load(
            block_tables_hbm[i_p + i * B_P_SIZE, i_f],
            dtype=nl.int32,
            # skip rows past the end when num_tiles is not a multiple of
            # B_P_SIZE
            mask=(i_p + i * B_P_SIZE < num_tiles),
        )
    return block_tables_sbuf
@nki.jit
def transform_block_tables_for_indirect_load(
    block_tables,
    block_size_tiling_factor,
    num_head,
    head_id,
):
    """
    This function does two things:
    1. calculate new `block_tables` for a `head_id` after flattening
    `num_block`, `num_head`, and `block_size_tiling_factor` dimensions
    2. transpose the result so that `block_table` for each tile is mapped to
    SBUF Partition dimension for vectorized DMA

    Tiling trick to further improve DMA performance:
    Given KV cache shape `(num_block, num_head, block_size, D)`, when loading M
    blocks of a given `head_id` from HBM, the load `cache[block_tables,
    head_id]` has shape `(M, block_size, D)`. If M < B_P_SIZE = 128, DMA may not
    fully utilize hardware parallelization. The solution is to tile `block_size`
    into `(block_size_tiling_factor, tiled_block_size)` s.t. `M *
    block_size_tiling_factor = B_P_SIZE`. After tiling, KV cache has shape
    `(num_block, num_head, block_size_tiling_factor, tiled_block_size, D)`.

    Note:
    We don't further tile D dimension as small DMA size also hurts performance.

    Input `block_tables` is 3-D,
    `(num_partitions, num_tiles_per_partition, num_blocks_per_tile)`, with
    `num_tiles_per_partition == B_P_SIZE` and `num_blocks_per_tile` a power
    of 2 (both asserted). The int32 result has shape
    `(num_loads, par_dim(B_P_SIZE), num_partitions * num_tiles_per_partition)`.
    """
    B_P_SIZE = 128
    num_partitions, num_tiles_per_partition, num_blocks_per_tile = (
        block_tables.shape)
    assert num_tiles_per_partition == B_P_SIZE
    assert is_power_of_2(
        num_blocks_per_tile), f"{num_blocks_per_tile=} is not power of 2"

    num_loads = cdiv(num_blocks_per_tile, B_P_SIZE)
    block_tables_transposed = nl.ndarray(
        (
            num_loads,
            par_dim(B_P_SIZE),
            num_partitions * num_tiles_per_partition,
        ),
        dtype=nl.int32,
    )

    # prepare iota ahead of time to avoid repeatedly using Gpsimd
    if num_head > 1:
        # materialize head_id as a tensor broadcast to the per-partition
        # block-table shape so it can be added element-wise below
        head_id = nisa.iota(head_id, dtype=nl.int32).reshape((1, 1))
        head_id = nl.transpose(
            head_id.broadcast_to((1, num_tiles_per_partition)))
        if num_blocks_per_tile > 1:
            head_id = head_id.broadcast_to(
                (num_tiles_per_partition, num_blocks_per_tile))

    if block_size_tiling_factor > 1:
        broadcast_shape = (
            num_tiles_per_partition,
            num_blocks_per_tile,
            block_size_tiling_factor,
        )
        # offsets 0..block_size_tiling_factor-1 along the new innermost axis
        offset = nisa.iota(
            nl.arange(block_size_tiling_factor)[None, None, :],
            dtype=nl.int32).broadcast_to(broadcast_shape)

    for partition_id in nl.affine_range(num_partitions):
        block_tables_partition = block_tables[partition_id]
        if num_head > 1:
            # fuse num_block and num_head dimension
            block_tables_partition = block_tables_partition * num_head + head_id

        if block_size_tiling_factor > 1:
            # need to apply block size tiling trick
            assert num_blocks_per_tile * block_size_tiling_factor == B_P_SIZE
            # each block id expands into block_size_tiling_factor consecutive
            # ids: id * factor + (0 .. factor-1)
            block_tables_partition = ((block_tables_partition *
                                       block_size_tiling_factor).reshape(
                                           (num_tiles_per_partition,
                                            num_blocks_per_tile,
                                            1)).broadcast_to(broadcast_shape))
            new_block_tables = block_tables_partition + offset
            new_block_tables = new_block_tables.reshape(
                (num_tiles_per_partition, B_P_SIZE))
        else:
            new_block_tables = block_tables_partition

        # transpose the block table so that it can be used by vector DGE
        for i in nl.affine_range(num_loads):
            i_p = nl.arange(B_P_SIZE)[:, None]
            i_f = (partition_id * num_tiles_per_partition +
                   nl.arange(num_tiles_per_partition)[None, :])
            block_tables_transposed[i, i_p, i_f] = nl.transpose(
                new_block_tables[:, nl.ds(i * B_P_SIZE, B_P_SIZE)])
    return block_tables_transposed
@nki.jit
def load_kv_tile_from_cache(
    cur_k_tile,
    cur_v_tile,
    kv_cache,
    block_tables,
    large_k_tile_idx,
    num_blocks_per_large_tile,
    tiled_block_size,
    B_P_SIZE,
    B_D_SIZE,
):
    """
    Load KV cache and transform Key and Value into layout required by Matmul

    Vectorized DMA Load layout:
    Key and Value: (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE)

    Layout used by attention matmuls:
    Key: (par_dim(B_D_SIZE), seqlen_kv)
    Value: (seqlen_kv // B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE)
           equivalent to (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE)

    `kv_cache` is indexed as `kv_cache[0, ...]` for keys and
    `kv_cache[1, ...]` for values. Rows are selected indirectly through
    `block_tables[load_idx, :, large_k_tile_idx]`. Results are written in
    place into `cur_k_tile` and `cur_v_tile`; loaded data is cast when the
    destination dtype differs.
    """
    num_loads = cdiv(num_blocks_per_large_tile, B_P_SIZE)

    # The index grids are identical for the key and value loads, so define
    # them once up front rather than relying on names leaking out of the key
    # loop into the value loop.
    i_p = nl.arange(B_P_SIZE)[:, None]
    i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :]

    # load key cache
    for load_idx in nl.affine_range(num_loads):
        loaded = nl.load(kv_cache[0, block_tables[load_idx, i_p,
                                                  large_k_tile_idx], i_f])
        if cur_k_tile.dtype != loaded.dtype:
            loaded = nl.copy(loaded, dtype=cur_k_tile.dtype)
        # Transpose SBUF tensor using PE
        for tb_i in nl.affine_range(tiled_block_size):
            cur_k_tile[
                :,
                nl.ds(
                    load_idx * B_P_SIZE * tiled_block_size + tb_i * B_P_SIZE,
                    B_P_SIZE,
                ),
            ] = nl.transpose(loaded[:, nl.ds(tb_i * B_D_SIZE, B_D_SIZE)])

    # load value cache
    for load_idx in nl.affine_range(num_loads):
        loaded = nl.load(kv_cache[1, block_tables[load_idx, i_p,
                                                  large_k_tile_idx], i_f])
        if cur_v_tile.dtype != loaded.dtype:
            loaded = nl.copy(loaded, dtype=cur_v_tile.dtype)
        # Values keep the DMA load layout, so no transpose is needed.
        cur_v_tile[
            :,
            nl.ds(
                load_idx * tiled_block_size * B_D_SIZE,
                tiled_block_size * B_D_SIZE,
            ),
        ] = loaded
@nki.jit
def transpose_p_local(p_local_transposed,
                      p_local,
                      LARGE_TILE_SZ,
                      B_F_SIZE=512):
    """
    Transpose `p_local` in 128-column slices and write the result into
    `p_local_transposed`.

    The free dimension (`LARGE_TILE_SZ`) is processed in chunks of
    `B_F_SIZE`. Each chunk is assembled in a temporary
    `(par_dim(128), B_F_SIZE)` buffer and then copied into the output, cast
    to the output dtype. On `nc_version.gen3` the temporary lives in SBUF
    with `p_local`'s dtype and slices are transposed with
    `nisa.dma_transpose`; otherwise it lives in PSUM as float32 and
    `nisa.nc_transpose` is used.
    """
    for i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE):
        if nisa.get_nc_version() == nisa.nc_version.gen3:
            p_local_t_tmp = nl.ndarray((par_dim(128), B_F_SIZE),
                                       buffer=nl.sbuf,
                                       dtype=p_local.dtype)
        else:
            p_local_t_tmp = nl.ndarray((par_dim(128), B_F_SIZE),
                                       buffer=nl.psum,
                                       dtype=np.float32)

        for j in nl.affine_range(B_F_SIZE // 128):
            # j_128_slice: destination columns inside the temporary chunk
            # i_j_128_slice: source columns inside the full p_local tile
            j_128_slice = nl.ds(j * 128, 128)
            i_j_128_slice = nl.ds(i * B_F_SIZE + j * 128, 128)

            if nisa.get_nc_version() == nisa.nc_version.gen3:
                p_local_t_tmp[:, j_128_slice] = nisa.dma_transpose(
                    p_local[:, i_j_128_slice])
            else:
                p_local_t_tmp[:, j_128_slice] = nisa.nc_transpose(
                    p_local[:, i_j_128_slice])

        p_local_transposed[:, nl.ds(i * B_F_SIZE, B_F_SIZE)] = nl.copy(
            p_local_t_tmp, dtype=p_local_transposed.dtype)
@nki.jit
def _flash_attention_core(
    q_local_tile,
    k,
    v,
    o_buffer,
    l_buffer,
    m_buffer,
    kernel_dtype,
    acc_type,
    tile_mask,
    use_causal_mask,
    q_tile_idx=None,
    initialize=False,
    LARGE_TILE_SZ=2048,
    B_P_SIZE=128,
    B_F_SIZE=512,
    B_D_SIZE=128,
    qk_res_buffer=None,
):
    """
    The flash attention core function to calculate self attention between a tile
    of q and a block of K and V.
    The q_local_tile has (B_P_SIZE, B_D_SIZE)
    The K and V have shape (B_D_SIZE, LARGE_TILE_SZ), whose free dimension will
    be split into size B_F_SIZE tiles

    The results are stored in the following three buffers
    o_buffer: (B_P_SIZE, d)
    l_buffer: (B_P_SIZE, 1)
    m_buffer: (B_P_SIZE, 1)

    All IO buffers are in SBUF.

    Buffer semantics as implemented below:
    - m_buffer holds the running row maximum of the masked QK scores.
    - l_buffer holds the running log-sum-exp: `log(sum) + max` on the first
      call, then `m_current + log(exp(l_prev - m_current) + sum)`.
    - o_buffer holds the un-normalized output; previous contents are
      rescaled by `exp(m_previous - m_current)` before the new P @ V
      contribution is added.

    When `initialize` is True the buffers are overwritten instead of
    updated. When `use_causal_mask` is True, `q_tile_idx` is used to skip the
    QK matmul for key tiles that start after the query tile. If
    `qk_res_buffer` is given, the masked QK scores are copied into it.

    NOTE(review): `v` is sliced as `v[:, k_i * B_D_SIZE : +B_D_SIZE]` for
    `k_i` in `LARGE_TILE_SZ // B_P_SIZE`, which suggests a layout of
    `(B_P_SIZE, LARGE_TILE_SZ // B_P_SIZE * B_D_SIZE)` rather than the
    `(B_D_SIZE, LARGE_TILE_SZ)` stated above - confirm against callers.
    """
    num_k_tile_per_large_tile = LARGE_TILE_SZ // B_F_SIZE

    qk_res_buf = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
                            buffer=nl.sbuf,
                            dtype=acc_type)
    max_local = nl.ndarray((par_dim(B_P_SIZE), num_k_tile_per_large_tile),
                           dtype=acc_type)
    for k_i in nl.affine_range(num_k_tile_per_large_tile):
        k_i_b_f_slice = nl.ds(k_i * B_F_SIZE, B_F_SIZE)

        if use_causal_mask:
            # mask are used to only apply computation to the lower half of the
            # matrix, which reduce the arithmetic intensity by up to 50%
            multiplication_required_selection = (q_tile_idx * B_P_SIZE
                                                 >= k_i * B_F_SIZE)
        else:
            multiplication_required_selection = True

        if multiplication_required_selection:
            qk_psum = nl.ndarray((par_dim(B_P_SIZE), B_F_SIZE),
                                 dtype=np.float32,
                                 buffer=nl.psum)  # (128, 512)
            qk_psum[:, :] = nl.matmul(q_local_tile,
                                      k[:, k_i_b_f_slice],
                                      transpose_x=True)  # (p(128), 512)
            # masked-out positions get -9984.0 (presumably a stand-in for
            # -inf that stays representable in low precision - confirm)
            qk_res_buf[:, k_i_b_f_slice] = nl.where(
                tile_mask[:, k_i_b_f_slice],
                qk_psum[:, nl.ds(0, B_F_SIZE)],
                -9984.0,
                dtype=acc_type,
            )
        else:
            # tile skipped by the causal check: fill with the mask value
            qk_res_buf[:, k_i_b_f_slice] = -9984.0

        # Calculate max of the current tile
        max_local[:, k_i] = nisa.tensor_reduce(
            np.max,
            qk_res_buf[:, k_i_b_f_slice],
            axis=(1, ),
            dtype=acc_type,
            negate=False,
        )

    if qk_res_buffer is not None:
        qk_res_buffer[:, :] = nl.copy(qk_res_buf[:, :])

    # row maximum over the whole large tile
    max_ = nisa.tensor_reduce(
        np.max,
        max_local[:, :],
        axis=(1, ),
        dtype=acc_type,
        negate=False,
    )

    o_previous_scaled = nl.ndarray((par_dim(B_P_SIZE), B_D_SIZE),
                                   dtype=o_buffer.dtype)

    if initialize:
        m_buffer[:, 0] = nl.copy(max_)
        m_current = max_
    else:
        m_previous = nl.copy(m_buffer[:, 0])
        m_buffer[:, 0] = nl.maximum(m_previous, max_)  # (128,1)

        m_current = m_buffer[:, 0]
        # Compute scaling factor
        alpha = nisa.activation(
            np.exp,
            m_previous,
            bias=-1 * m_current,
            scale=1.0,
        )
        o_previous_scaled[...] = nl.multiply(o_buffer[:, :], alpha)

    p_local = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
                         dtype=kernel_dtype)
    REDUCTION_TILE = min(2048, LARGE_TILE_SZ // 2)

    p_partial_sum = nl.ndarray(
        (par_dim(B_P_SIZE), LARGE_TILE_SZ // REDUCTION_TILE),
        dtype=acc_type,
    )

    for k_r_i in nl.affine_range(LARGE_TILE_SZ // REDUCTION_TILE):
        k_r_i_reduce_slice = nl.ds(k_r_i * REDUCTION_TILE, REDUCTION_TILE)

        # compute exp(qk - max)
        # Compute partial row - tile sum of exp(qk - max))
        # FIXME : Use activation accumulate to accumulate over k_r_i loop ?
        p_local[:, k_r_i_reduce_slice] = nisa.activation_reduce(
            np.exp,
            qk_res_buf[:, k_r_i_reduce_slice],
            bias=-1 * m_current,
            scale=1.0,
            reduce_op=nl.add,
            reduce_res=p_partial_sum[:, k_r_i],
            dtype=kernel_dtype,
        )

    # total row sum of exp(qk - max) over the large tile
    ps = nl.sum(p_partial_sum, axis=1, dtype=acc_type)

    p_local_transposed = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
                                    dtype=kernel_dtype)
    transpose_p_local(
        p_local_transposed=p_local_transposed,
        p_local=p_local,
        LARGE_TILE_SZ=LARGE_TILE_SZ,
        B_F_SIZE=B_F_SIZE,
    )

    pv_psum = nl.zeros(
        (par_dim(B_P_SIZE), B_D_SIZE),
        dtype=np.float32,
        buffer=nl.psum,
    )
    # accumulate P @ V over B_P_SIZE-wide slices of the large tile
    for k_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE):
        pv_psum[:, :] += nl.matmul(
            p_local_transposed[:, nl.ds(k_i * B_P_SIZE, B_P_SIZE)],
            v[:, nl.ds(k_i * B_D_SIZE, B_D_SIZE)],
            transpose_x=True,
        )  # (128, 128) (p(Br), d)

    if initialize:
        o_buffer[:, :] = nl.copy(pv_psum[:, :])
        l_buffer[:, 0] = nl.add(nl.log(ps), max_)
    else:
        o_buffer[:, :] = nl.add(o_previous_scaled, pv_psum)

        l_prev = l_buffer[:, 0]
        l_exp = nl.add(
            nl.exp(nl.subtract(l_prev, m_current)),
            ps,
        )
        l_buffer[:, 0] = nl.add(m_current, nl.log(l_exp))
@nki.jit
def load_v_tile(v_hbm_tile, cur_v_tile, large_tile_idx, v_i, LARGE_TILE_SZ):
    """
    Load one `B_P_SIZE`-row slice of `v_hbm_tile` into `cur_v_tile`.

    The rows read start at `large_tile_idx * LARGE_TILE_SZ + B_P_SIZE * v_i`.
    The slice is cast to `cur_v_tile`'s dtype when the dtypes differ and is
    written to columns `[v_i * B_D_SIZE, (v_i + 1) * B_D_SIZE)` of
    `cur_v_tile`, where `B_D_SIZE` is the last dimension of `v_hbm_tile`.
    """
    B_P_SIZE = 128
    B_D_SIZE = v_hbm_tile.shape[-1]
    loaded = nl.load(v_hbm_tile[
        nl.ds(large_tile_idx * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE),
        :,
    ])
    if cur_v_tile.dtype != loaded.dtype:
        loaded = nl.copy(loaded, dtype=cur_v_tile.dtype)
    cur_v_tile[:, nl.ds(v_i * B_D_SIZE, B_D_SIZE)] = loaded
@
nki
.
jit
def
flash_paged_attention
(
query
,
key
,
value
,
kv_cache
,
block_tables
,
mask
,
softmax_scale
=
None
,
mixed_precision
=
True
,
LARGE_TILE_SZ
=
2048
,
return_debug_tensors
=
False
,
):
"""
Flash PagedAttention Forward Kernel.
IO tensor layouts:
- query: shape (1, n_heads, d, seq_q)
- key: shape (1, n_kv_heads, d, seq_k)
- value: shape (1, n_kv_heads, seq_v, d)
- kv_cache: (2, num_blocks, n_kv_heads, block_size, d)
- block_tables: (num_active_blocks, )
- mask: (seq_q, num_active_blocks * block_size + seq_q)
- o: shape (1, n_heads, seq_q, d)
- This kernel requires seq_k == seq_v
- We use continuous batching by default, so the batch dimension is
always 1, and different requests are concatenated along sequence
dimension.
- We use paged cache blocks (kv_cache) to store KV cache.
IO tensor dtypes:
- This kernel assumes all IO tensors have the same dtype except for
block_tables (int32) and mask (int32)
- If mixed_precision is True, then all Tensor Engine operation will be
performed in bfloat16 and accumulation will be performed in float32.
Otherwise the intermediates will be in the same type as the inputs.
Compile-time Constants:
- softmax_scale: scaling for softmax, is None, default is `1.0/(d**0.5)`
- mixed_precision: flag to set non-matmul ops in fp32 precision, default
is set to `true`, if false, we use same precision as input types
- LARGE_TILE_SZ: `default=2048`, size of the kv tile size for attention
computation reduction
GQA support Notes:
the spmd kernel for launching kernel should be on kv_heads instead of
nheads
Example usage:
MHA: q: [b, h, d, s], k: [b, h, d, s], v: [b, h, s, d]
usage: `flash_fwd[b, h](q, k, v, ...)`
GQA: q: [b, h, d, s], k: [b, kv_h, d, s], v: [b, kv_h, s, d]
usage: `flash_fwd[b, kv_h](q, k, v, ...)`
"""
B_F_SIZE
=
512
B_P_SIZE
=
128
b
,
h
,
d
,
seqlen_q
=
query
.
shape
B_D_SIZE
=
d
n_tile_q
=
seqlen_q
//
B_P_SIZE
# since q will be loaded on tensor engine
_
,
num_blocks
,
k_h
,
block_size
,
_
=
kv_cache
.
shape
q_h_per_k_h
=
h
//
k_h
assert
b
==
1
,
f
"invalid batch size
{
b
=
}
"
assert
d
<=
128
,
f
" we do not support head_dim > 128, got head dim
{
d
=
}
"
cache_shape
=
(
2
,
num_blocks
,
k_h
,
block_size
,
d
)
assert
(
tuple
(
kv_cache
.
shape
)
==
cache_shape
),
f
"
{
kv_cache
.
shape
=
}
mismatch, expect
{
cache_shape
}
"
assert
key
is
None
or
tuple
(
key
.
shape
)
==
(
1
,
k_h
,
d
,
seqlen_q
,
),
f
"key shape
{
key
.
shape
}
mismatch!"
assert
value
is
None
or
tuple
(
value
.
shape
)
==
(
1
,
k_h
,
seqlen_q
,
d
,
),
f
"value shape
{
value
.
shape
}
mismatch!"
assert
(
nl
.
program_ndim
()
==
2
),
f
"Expect spmd grid with 2 dimensions, got
{
nl
.
program_ndim
()
}
instead!"
batch_id
=
nl
.
program_id
(
axis
=
0
)
head_id
=
nl
.
program_id
(
axis
=
1
)
(
num_active_blocks
,
)
=
block_tables
.
shape
context_kv_len
=
num_active_blocks
*
block_size
assert
(
LARGE_TILE_SZ
%
B_F_SIZE
==
0
),
f
"Need
{
LARGE_TILE_SZ
=
}
to be divisible by
{
B_F_SIZE
=
}
in transpose_p"
assert
(
context_kv_len
%
LARGE_TILE_SZ
==
0
),
f
"Need
{
context_kv_len
=
}
to be divisible by
{
LARGE_TILE_SZ
=
}
"
num_blocks_per_large_tile
=
LARGE_TILE_SZ
//
block_size
assert
is_power_of_2
(
num_blocks_per_large_tile
),
f
"
{
num_blocks_per_large_tile
=
}
is expected of be power of 2"
if
seqlen_q
>
B_F_SIZE
:
MAX_REDUCTION_TILE
=
2048
if
seqlen_q
//
2
>
MAX_REDUCTION_TILE
:
assert
(
seqlen_q
%
MAX_REDUCTION_TILE
==
0
),
f
"
{
seqlen_q
=
}
should be divisible by
{
MAX_REDUCTION_TILE
=
}
"
else
:
assert
(
seqlen_q
%
B_F_SIZE
==
0
),
f
"
{
seqlen_q
=
}
should be divisible by
{
B_F_SIZE
=
}
)"
kernel_dtype
=
nl
.
bfloat16
if
mixed_precision
else
query
.
dtype
acc_type
=
np
.
dtype
(
np
.
float32
)
if
mixed_precision
else
kernel_dtype
softmax_scale
=
softmax_scale
or
(
1.0
/
(
d
**
0.5
))
num_large_k_tile
=
context_kv_len
//
LARGE_TILE_SZ
o
=
nl
.
ndarray
((
b
,
h
,
seqlen_q
,
d
),
dtype
=
query
.
dtype
,
buffer
=
nl
.
shared_hbm
)
hbm_l_buffer
,
hbm_m_buffer
,
hbm_qk_res
,
qk_res_buffer
=
(
None
,
None
,
None
,
None
,
)
if
return_debug_tensors
:
hbm_l_buffer
=
nl
.
ndarray
((
b
,
h
,
seqlen_q
),
dtype
=
acc_type
,
buffer
=
nl
.
shared_hbm
)
hbm_m_buffer
=
nl
.
ndarray
((
b
,
h
,
seqlen_q
),
dtype
=
acc_type
,
buffer
=
nl
.
shared_hbm
)
hbm_qk_res
=
nl
.
ndarray
((
b
,
h
,
B_P_SIZE
,
seqlen_q
),
dtype
=
acc_type
,
buffer
=
nl
.
shared_hbm
)
qk_res_buffer
=
nl
.
zeros
(
(
n_tile_q
,
q_h_per_k_h
,
par_dim
(
B_P_SIZE
),
seqlen_q
),
dtype
=
acc_type
,
buffer
=
nl
.
sbuf
,
lazy_initialization
=
True
,
)
block_tables_sbuf
=
load_block_tables
(
block_tables_hbm
=
block_tables
,
num_tiles
=
num_large_k_tile
,
num_blocks_per_tile
=
num_blocks_per_large_tile
,
)
# On Neuron, we need B_P_SIZE = 128 blocks to make DMA efficient
if
num_blocks_per_large_tile
<
B_P_SIZE
:
# we checked num_blocks_per_tile is a power of 2
assert
B_P_SIZE
%
num_blocks_per_large_tile
==
0
block_size_tiling_factor
=
B_P_SIZE
//
num_blocks_per_large_tile
# We assume block_size >= block_size_tiling_factor
assert
block_size
%
block_size_tiling_factor
==
0
else
:
block_size_tiling_factor
=
1
tiled_block_size
=
block_size
//
block_size_tiling_factor
# Indirect DMA load must be placed along Partition Dimension
block_tables_sbuf
=
transform_block_tables_for_indirect_load
(
block_tables_sbuf
,
block_size_tiling_factor
=
block_size_tiling_factor
,
num_head
=
k_h
,
head_id
=
head_id
,
)
# Flatten KV cache to be 3D for loading into SBUF
new_cache_shape
=
(
2
,
num_blocks
*
k_h
*
block_size_tiling_factor
,
tiled_block_size
*
d
,
)
kv_cache
=
kv_cache
.
reshape
(
new_cache_shape
)
# Global Flash Attention accumulators
o_buffer
=
nl
.
zeros
(
(
n_tile_q
,
q_h_per_k_h
,
par_dim
(
B_P_SIZE
),
d
),
dtype
=
acc_type
,
buffer
=
nl
.
sbuf
,
lazy_initialization
=
True
,
)
l_buffer
=
nl
.
zeros
(
(
n_tile_q
,
q_h_per_k_h
,
par_dim
(
B_P_SIZE
),
1
),
dtype
=
acc_type
,
buffer
=
nl
.
sbuf
,
lazy_initialization
=
True
,
)
m_buffer
=
nl
.
zeros
(
(
n_tile_q
,
q_h_per_k_h
,
par_dim
(
B_P_SIZE
),
1
),
dtype
=
acc_type
,
buffer
=
nl
.
sbuf
,
lazy_initialization
=
True
,
)
for
large_k_tile_idx
in
nl
.
sequential_range
(
0
,
num_large_k_tile
):
num_loads
=
cdiv
(
num_blocks_per_large_tile
,
B_P_SIZE
)
cur_k_tile
=
nl
.
ndarray
(
(
par_dim
(
B_D_SIZE
),
LARGE_TILE_SZ
),
dtype
=
kernel_dtype
,
)
cur_v_tile
=
nl
.
ndarray
(
(
par_dim
(
B_P_SIZE
),
num_loads
*
tiled_block_size
*
B_D_SIZE
),
dtype
=
kernel_dtype
,
)
load_kv_tile_from_cache
(
cur_k_tile
=
cur_k_tile
,
cur_v_tile
=
cur_v_tile
,
kv_cache
=
kv_cache
,
block_tables
=
block_tables_sbuf
,
large_k_tile_idx
=
large_k_tile_idx
,
num_blocks_per_large_tile
=
num_blocks_per_large_tile
,
tiled_block_size
=
tiled_block_size
,
B_P_SIZE
=
B_P_SIZE
,
B_D_SIZE
=
B_D_SIZE
,
)
for
i
in
nl
.
affine_range
(
n_tile_q
):
cur_mask
=
nl
.
load
(
mask
[
nl
.
ds
(
i
*
B_P_SIZE
,
B_P_SIZE
),
nl
.
ds
(
large_k_tile_idx
*
LARGE_TILE_SZ
,
LARGE_TILE_SZ
),
])
for
i_q_h
in
nl
.
affine_range
(
q_h_per_k_h
):
q_tile
=
nl
.
ndarray
((
B_D_SIZE
,
B_P_SIZE
),
dtype
=
kernel_dtype
)
q_hbm_tile
=
query
[
batch_id
,
head_id
*
q_h_per_k_h
+
i_q_h
]
q_sbuf_tile
=
nl
.
load
(
q_hbm_tile
[:,
nl
.
ds
(
i
*
B_P_SIZE
,
B_P_SIZE
)])
if
q_sbuf_tile
.
dtype
!=
kernel_dtype
:
q_sbuf_tile
=
nl
.
copy
(
q_sbuf_tile
,
dtype
=
kernel_dtype
)
q_tile
[:,
:]
=
q_sbuf_tile
*
softmax_scale
_flash_attention_core
(
q_local_tile
=
q_tile
,
k
=
cur_k_tile
,
v
=
cur_v_tile
,
o_buffer
=
o_buffer
[
i
,
i_q_h
],
l_buffer
=
l_buffer
[
i
,
i_q_h
],
m_buffer
=
m_buffer
[
i
,
i_q_h
],
kernel_dtype
=
kernel_dtype
,
acc_type
=
acc_type
,
tile_mask
=
cur_mask
,
use_causal_mask
=
False
,
q_tile_idx
=
i
,
initialize
=
large_k_tile_idx
==
0
,
LARGE_TILE_SZ
=
LARGE_TILE_SZ
,
B_P_SIZE
=
B_P_SIZE
,
B_F_SIZE
=
B_F_SIZE
,
B_D_SIZE
=
B_D_SIZE
,
)
# compute attention between input query, key and value
if
key
is
not
None
and
value
is
not
None
:
B_F_SIZE
=
min
(
seqlen_q
,
B_F_SIZE
)
LARGE_TILE_SZ
=
seqlen_q
cur_k_tile
=
nl
.
ndarray
((
par_dim
(
B_D_SIZE
),
LARGE_TILE_SZ
),
dtype
=
kernel_dtype
)
cur_v_tile
=
nl
.
ndarray
(
(
par_dim
(
B_P_SIZE
),
LARGE_TILE_SZ
//
B_P_SIZE
*
B_D_SIZE
),
dtype
=
kernel_dtype
,
)
loaded
=
nl
.
load
(
key
[
batch_id
,
head_id
,
:,
:])
if
loaded
.
dtype
!=
kernel_dtype
:
loaded
=
nl
.
copy
(
loaded
,
dtype
=
kernel_dtype
)
cur_k_tile
[:,
:]
=
loaded
v_hbm_tile
=
value
[
batch_id
,
head_id
]
for
v_i
in
nl
.
affine_range
(
LARGE_TILE_SZ
//
B_P_SIZE
):
load_v_tile
(
v_hbm_tile
=
v_hbm_tile
,
cur_v_tile
=
cur_v_tile
,
large_tile_idx
=
0
,
v_i
=
v_i
,
LARGE_TILE_SZ
=
LARGE_TILE_SZ
,
)
for
i
in
nl
.
affine_range
(
n_tile_q
):
cur_mask
=
nl
.
load
(
mask
[
nl
.
ds
(
i
*
B_P_SIZE
,
B_P_SIZE
),
nl
.
ds
(
context_kv_len
,
LARGE_TILE_SZ
),
])
for
i_q_h
in
nl
.
affine_range
(
q_h_per_k_h
):
q_tile
=
nl
.
ndarray
((
B_D_SIZE
,
B_P_SIZE
),
dtype
=
kernel_dtype
)
q_hbm_tile
=
query
[
batch_id
,
head_id
*
q_h_per_k_h
+
i_q_h
]
q_sbuf_tile
=
nl
.
load
(
q_hbm_tile
[:,
nl
.
ds
(
i
*
B_P_SIZE
,
B_P_SIZE
)])
if
q_sbuf_tile
.
dtype
!=
kernel_dtype
:
q_sbuf_tile
=
nl
.
copy
(
q_sbuf_tile
,
dtype
=
kernel_dtype
)
q_tile
[:,
:]
=
q_sbuf_tile
*
softmax_scale
_flash_attention_core
(
q_local_tile
=
q_tile
,
k
=
cur_k_tile
,
v
=
cur_v_tile
,
o_buffer
=
o_buffer
[
i
,
i_q_h
],
l_buffer
=
l_buffer
[
i
,
i_q_h
],
m_buffer
=
m_buffer
[
i
,
i_q_h
],
kernel_dtype
=
kernel_dtype
,
acc_type
=
acc_type
,
tile_mask
=
cur_mask
,
use_causal_mask
=
True
,
q_tile_idx
=
i
,
initialize
=
False
,
LARGE_TILE_SZ
=
LARGE_TILE_SZ
,
B_P_SIZE
=
B_P_SIZE
,
B_F_SIZE
=
B_F_SIZE
,
B_D_SIZE
=
B_D_SIZE
,
qk_res_buffer
=
(
qk_res_buffer
[
i
,
i_q_h
]
if
qk_res_buffer
is
not
None
else
None
),
)
# -- -- -- -- write output to buffer on HBM -- -- -- -- -- -- #
for
i_q_h
in
nl
.
affine_range
(
q_h_per_k_h
):
for
i
in
nl
.
affine_range
(
n_tile_q
):
out
=
nl
.
multiply
(
o_buffer
[
i
,
i_q_h
],
nl
.
exp
(
m_buffer
[
i
,
i_q_h
]
-
l_buffer
[
i
,
i_q_h
]),
dtype
=
kernel_dtype
,
)
nl
.
store
(
o
[
batch_id
,
head_id
*
q_h_per_k_h
+
i_q_h
,
nl
.
ds
(
i
*
B_P_SIZE
,
B_P_SIZE
),
:,
],
out
,
)
# maximum and summation statistics
if
return_debug_tensors
:
nl
.
store
(
hbm_m_buffer
[
batch_id
,
head_id
*
q_h_per_k_h
+
i_q_h
,
nl
.
ds
(
i
*
B_P_SIZE
,
B_P_SIZE
),
],
m_buffer
[
i
,
i_q_h
,
:,
:],
)
nl
.
store
(
hbm_l_buffer
[
batch_id
,
head_id
*
q_h_per_k_h
+
i_q_h
,
nl
.
ds
(
i
*
B_P_SIZE
,
B_P_SIZE
),
],
l_buffer
[
i
,
i_q_h
],
)
nl
.
store
(
hbm_qk_res
[
batch_id
,
head_id
*
q_h_per_k_h
+
i_q_h
,
:,
:],
qk_res_buffer
[
batch_id
,
i_q_h
,
:,
:],
)
if
return_debug_tensors
:
return
o
,
hbm_m_buffer
,
hbm_l_buffer
,
hbm_qk_res
return
o
def reorder_context_mask(mask, LARGE_TILE_SZ, block_size):
    """
    Reorder the context columns of `mask` to match the kernel's token order.

    The flash attention kernel vectorizes its KV cache reads to improve DMA
    utilization, and the resulting layout changes the order in which tokens
    are consumed: inside a tile of `B_P_SIZE * tiled_block_size` tokens the
    inner layout is (B_P_SIZE, tiled_block_size) and the engine walks it
    column by column. The context part of the mask is therefore transposed
    on those two inner axes so that mask columns line up with that order.

    Args:
        mask: 2-D tensor of shape (total_query_len, total_seq_len). The
            trailing `total_query_len` columns are passed through unchanged;
            all columns before them are treated as the context part.
        LARGE_TILE_SZ: tile width in tokens; must be at least 128.
        block_size: number of tokens per KV cache block.

    Returns:
        `mask` itself when the derived tiled block size is 1 (no reordering
        needed); otherwise a new tensor of the same shape, computed on CPU
        and moved back to the device `mask` was on.
    """
    n_query, n_total = mask.shape
    n_context = n_total - n_query

    # Kept under this exact name: the assertion message prints "B_P_SIZE=...".
    B_P_SIZE = 128
    assert (LARGE_TILE_SZ
            >= B_P_SIZE), f"{LARGE_TILE_SZ=} must be larger than {B_P_SIZE=}"

    blocks_per_tile = max(B_P_SIZE, LARGE_TILE_SZ // block_size)
    tokens_per_tiled_block = LARGE_TILE_SZ // blocks_per_tile
    if tokens_per_tiled_block <= 1:
        # One token per tiled block: consumption order equals storage order.
        return mask

    # Mask reordering is needed when the tiled block size exceeds 1.
    origin_device = mask.device
    host_mask = mask.cpu()

    # Split every context tile into (groups of B_P_SIZE blocks, B_P_SIZE,
    # tokens per tiled block), then swap the last two axes.
    tiled_context = host_mask[:, :n_context].view(
        n_query,
        n_context // LARGE_TILE_SZ,
        blocks_per_tile // B_P_SIZE,
        B_P_SIZE,
        tokens_per_tiled_block,
    )
    reordered_context = tiled_context.transpose(3, 4).reshape(
        n_query, n_context)

    untouched_tail = host_mask[:, n_context:]
    merged = torch.concat([reordered_context, untouched_tail], dim=1)
    return merged.to(origin_device)
def flash_attn_varlen_nkifunc(
    query,
    key,
    value,
    kv_cache,
    block_table,
    attn_mask,
    n_kv_head=None,
    head_size=None,
    LARGE_TILE_SZ=2048,
    mixed_precision=True,
):
    """
    Launch the flash paged attention NKI kernel for variable length sequences.

    Thin wrapper that fills in defaults from `kv_cache`, derives the softmax
    scale from the head size and invokes `flash_paged_attention` on a
    (1, n_kv_head) grid.

    Expected argument shapes:
      - query:      (1, n_heads, d, seq_q)
      - key:        (1, n_kv_heads, d, seq_k)
      - value:      (1, n_kv_heads, seq_v, d)
      - kv_cache:   (2, n_blocks, n_kv_heads, block_size, d)
      - block_table: (n_active_blocks, )
      - attn_mask:  (seq_q, n_active_blocks * block_size + seq_q)

    Notes:
      - attn_mask must already be reordered by the caller using
        `reorder_context_mask`.
      - Key/value cache layout must be (n_blocks, n_kv_heads, block_size, d)
        for better DMA throughput.

    Args:
        n_kv_head: number of KV heads; defaults to `kv_cache.shape[2]`.
        head_size: head dimension; defaults to `kv_cache.shape[-1]`.
        LARGE_TILE_SZ: tile size forwarded to the kernel.
        mixed_precision: forwarded to the kernel unchanged.

    Returns:
        Whatever `flash_paged_attention` returns for these arguments.
    """
    cache_dims = kv_cache.shape
    if n_kv_head is None:
        n_kv_head = cache_dims[2]
    assert cache_dims[0] == 2
    assert cache_dims[2] == n_kv_head
    if head_size is None:
        head_size = cache_dims[-1]

    # Build the argument set up front so the scale is computed before launch.
    kernel_args = {
        "query": query,
        "key": key,
        "value": value,
        "kv_cache": kv_cache,
        "block_tables": block_table,
        "mask": attn_mask,
        "softmax_scale": 1.0 / (head_size**0.5),
        "mixed_precision": mixed_precision,
        "LARGE_TILE_SZ": LARGE_TILE_SZ,
    }

    # Grid axis 0 is the (single) batch, axis 1 enumerates the KV heads.
    launch = flash_paged_attention[1, n_kv_head]
    return launch(**kernel_args)
def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> None:
    """
    Scatter per-token keys and values into the paged KV cache in place.

    Each entry of `slot_mapping` is a flat slot number; it is split into a
    block index (`slot // block_size`) and an offset inside that block
    (`slot % block_size`), where `block_size` is `kv_cache.size(3)`.

    Args:
        key (torch.Tensor): Key tensor with shape
            (num_tokens, n_kv_head, d_head)
        value (torch.Tensor): Value tensor with shape
            (num_tokens, n_kv_head, d_head)
        kv_cache (torch.Tensor): Key/value cache tensor with shape
            (2, num_blocks, n_kv_head, block_size, d_head); plane 0 receives
            keys and plane 1 receives values.
        slot_mapping (torch.Tensor): Mapping tensor indicating cache
            positions with shape (num_tokens)

    Returns:
        None: Updates the kv_cache tensor in-place
    """
    tokens_per_block = kv_cache.size(3)
    head_count = key.size(1)
    target_device = key.device

    # Column vectors (num_tokens, 1) so they broadcast against the head row.
    block_ids = torch.div(slot_mapping,
                          tokens_per_block,
                          rounding_mode="floor")[:, None]
    in_block_pos = (slot_mapping % tokens_per_block)[:, None]

    # Row vector (1, n_kv_head) selecting every head for every token.
    head_ids = torch.arange(head_count, device=target_device)[None, :]

    # Plane 0 holds keys, plane 1 holds values; both use the same indices.
    for plane, source in ((0, key), (1, value)):
        plane_idx = torch.tensor([plane], device=target_device)
        kv_cache.index_put_((plane_idx, block_ids, head_ids, in_block_pos),
                            source)
vllm/collect_env.py
View file @
4172235a
...
@@ -54,7 +54,6 @@ SystemEnv = namedtuple(
...
@@ -54,7 +54,6 @@ SystemEnv = namedtuple(
'is_xnnpack_available'
,
'is_xnnpack_available'
,
'cpu_info'
,
'cpu_info'
,
'rocm_version'
,
# vllm specific field
'rocm_version'
,
# vllm specific field
'neuron_sdk_version'
,
# vllm specific field
'vllm_version'
,
# vllm specific field
'vllm_version'
,
# vllm specific field
'vllm_build_flags'
,
# vllm specific field
'vllm_build_flags'
,
# vllm specific field
'gpu_topo'
,
# vllm specific field
'gpu_topo'
,
# vllm specific field
...
@@ -275,15 +274,6 @@ def get_rocm_version(run_lambda):
...
@@ -275,15 +274,6 @@ def get_rocm_version(run_lambda):
r
'HIP version: (\S+)'
)
r
'HIP version: (\S+)'
)
def
get_neuron_sdk_version
(
run_lambda
):
# Adapted from your install script
try
:
result
=
run_lambda
([
"neuron-ls"
])
return
result
if
result
[
0
]
==
0
else
'N/A'
except
Exception
:
return
'N/A'
def
get_vllm_version
():
def
get_vllm_version
():
from
vllm
import
__version__
,
__version_tuple__
from
vllm
import
__version__
,
__version_tuple__
...
@@ -306,10 +296,9 @@ def get_vllm_version():
...
@@ -306,10 +296,9 @@ def get_vllm_version():
def
summarize_vllm_build_flags
():
def
summarize_vllm_build_flags
():
# This could be a static method if the flags are constant, or dynamic if you need to check environment variables, etc.
# This could be a static method if the flags are constant, or dynamic if you need to check environment variables, etc.
return
'CUDA Archs: {}; ROCm:
{}; Neuron:
{}'
.
format
(
return
'CUDA Archs: {}; ROCm: {}'
.
format
(
os
.
environ
.
get
(
'TORCH_CUDA_ARCH_LIST'
,
'Not Set'
),
os
.
environ
.
get
(
'TORCH_CUDA_ARCH_LIST'
,
'Not Set'
),
'Enabled'
if
os
.
environ
.
get
(
'ROCM_HOME'
)
else
'Disabled'
,
'Enabled'
if
os
.
environ
.
get
(
'ROCM_HOME'
)
else
'Disabled'
,
'Enabled'
if
os
.
environ
.
get
(
'NEURON_CORES'
)
else
'Disabled'
,
)
)
...
@@ -601,7 +590,6 @@ def get_env_info():
...
@@ -601,7 +590,6 @@ def get_env_info():
conda_packages
=
get_conda_packages
(
run_lambda
)
conda_packages
=
get_conda_packages
(
run_lambda
)
rocm_version
=
get_rocm_version
(
run_lambda
)
rocm_version
=
get_rocm_version
(
run_lambda
)
neuron_sdk_version
=
get_neuron_sdk_version
(
run_lambda
)
vllm_version
=
get_vllm_version
()
vllm_version
=
get_vllm_version
()
vllm_build_flags
=
summarize_vllm_build_flags
()
vllm_build_flags
=
summarize_vllm_build_flags
()
gpu_topo
=
get_gpu_topo
(
run_lambda
)
gpu_topo
=
get_gpu_topo
(
run_lambda
)
...
@@ -635,7 +623,6 @@ def get_env_info():
...
@@ -635,7 +623,6 @@ def get_env_info():
is_xnnpack_available
=
is_xnnpack_available
(),
is_xnnpack_available
=
is_xnnpack_available
(),
cpu_info
=
get_cpu_info
(
run_lambda
),
cpu_info
=
get_cpu_info
(
run_lambda
),
rocm_version
=
rocm_version
,
rocm_version
=
rocm_version
,
neuron_sdk_version
=
neuron_sdk_version
,
vllm_version
=
vllm_version
,
vllm_version
=
vllm_version
,
vllm_build_flags
=
vllm_build_flags
,
vllm_build_flags
=
vllm_build_flags
,
gpu_topo
=
gpu_topo
,
gpu_topo
=
gpu_topo
,
...
@@ -702,7 +689,6 @@ env_info_fmt += """
...
@@ -702,7 +689,6 @@ env_info_fmt += """
vLLM Info
vLLM Info
==============================
==============================
ROCM Version : {rocm_version}
ROCM Version : {rocm_version}
Neuron SDK Version : {neuron_sdk_version}
vLLM Version : {vllm_version}
vLLM Version : {vllm_version}
vLLM Build Flags:
vLLM Build Flags:
{vllm_build_flags}
{vllm_build_flags}
...
...
vllm/config/__init__.py
View file @
4172235a
...
@@ -461,11 +461,6 @@ class ModelConfig:
...
@@ -461,11 +461,6 @@ class ModelConfig:
DP (which is controlled by `--data-parallel-size`).
DP (which is controlled by `--data-parallel-size`).
This is only supported on a per-model basis and falls back to
This is only supported on a per-model basis and falls back to
`"weights"` if the encoder does not support DP."""
`"weights"` if the encoder does not support DP."""
override_neuron_config
:
dict
[
str
,
Any
]
=
field
(
default_factory
=
dict
)
"""Initialize non-default neuron config or override default neuron config
that are specific to Neuron devices, this argument will be used to
configure the neuron config that can not be gathered from the vllm
arguments. e.g. `{"cast_logits_dtype": "bfloat16"}`."""
pooler_config
:
Optional
[
"PoolerConfig"
]
=
field
(
init
=
False
)
pooler_config
:
Optional
[
"PoolerConfig"
]
=
field
(
init
=
False
)
"""Pooler config which controls the behaviour of output pooling in pooling
"""Pooler config which controls the behaviour of output pooling in pooling
models."""
models."""
...
@@ -785,10 +780,6 @@ class ModelConfig:
...
@@ -785,10 +780,6 @@ class ModelConfig:
if
not
self
.
skip_tokenizer_init
:
if
not
self
.
skip_tokenizer_init
:
self
.
_verify_tokenizer_mode
()
self
.
_verify_tokenizer_mode
()
if
(
not
current_platform
.
is_neuron
()
and
self
.
override_neuron_config
):
raise
ValueError
(
"`override_neuron_config` is only supported on Neuron."
)
# Avoid running try_verify_and_update_config multiple times
# Avoid running try_verify_and_update_config multiple times
self
.
config_updated
=
False
self
.
config_updated
=
False
...
@@ -1696,13 +1687,7 @@ class ModelConfig:
...
@@ -1696,13 +1687,7 @@ class ModelConfig:
"""
"""
For Mllama, VLLM overrides HF's is_encoder_decoder flag and sets it to
For Mllama, VLLM overrides HF's is_encoder_decoder flag and sets it to
True to enable cross-attention
True to enable cross-attention
Neuron needs all multimodal data to be in the decoder and does not
need to explicitly enable cross-attention
"""
"""
if
(
current_platform
.
is_neuron
()
and
self
.
hf_config
.
model_type
==
"mllama"
):
return
False
return
is_encoder_decoder
(
self
.
hf_config
)
return
is_encoder_decoder
(
self
.
hf_config
)
@
property
@
property
...
@@ -1871,7 +1856,7 @@ class LoadConfig:
...
@@ -1871,7 +1856,7 @@ class LoadConfig:
self
.
ignore_patterns
=
[
"original/**/*"
]
self
.
ignore_patterns
=
[
"original/**/*"
]
Device
=
Literal
[
"auto"
,
"cuda"
,
"neuron"
,
"cpu"
,
"tpu"
,
"xpu"
]
Device
=
Literal
[
"auto"
,
"cuda"
,
"cpu"
,
"tpu"
,
"xpu"
]
@
config
@
config
...
@@ -1927,9 +1912,7 @@ class DeviceConfig:
...
@@ -1927,9 +1912,7 @@ class DeviceConfig:
self
.
device_type
=
self
.
device
.
type
self
.
device_type
=
self
.
device
.
type
# Some device types require processing inputs on CPU
# Some device types require processing inputs on CPU
if
self
.
device_type
in
[
"neuron"
]:
if
self
.
device_type
in
[
"tpu"
]:
self
.
device
=
torch
.
device
(
"cpu"
)
elif
self
.
device_type
in
[
"tpu"
]:
self
.
device
=
None
self
.
device
=
None
else
:
else
:
# Set device with device type
# Set device with device type
...
@@ -3941,7 +3924,6 @@ class VllmConfig:
...
@@ -3941,7 +3924,6 @@ class VllmConfig:
f
"skip_tokenizer_init=
{
self
.
model_config
.
skip_tokenizer_init
}
, "
f
"skip_tokenizer_init=
{
self
.
model_config
.
skip_tokenizer_init
}
, "
f
"tokenizer_mode=
{
self
.
model_config
.
tokenizer_mode
}
, "
f
"tokenizer_mode=
{
self
.
model_config
.
tokenizer_mode
}
, "
f
"revision=
{
self
.
model_config
.
revision
}
, "
f
"revision=
{
self
.
model_config
.
revision
}
, "
f
"override_neuron_config=
{
self
.
model_config
.
override_neuron_config
}
, "
# noqa
f
"tokenizer_revision=
{
self
.
model_config
.
tokenizer_revision
}
, "
f
"tokenizer_revision=
{
self
.
model_config
.
tokenizer_revision
}
, "
f
"trust_remote_code=
{
self
.
model_config
.
trust_remote_code
}
, "
f
"trust_remote_code=
{
self
.
model_config
.
trust_remote_code
}
, "
f
"dtype=
{
self
.
model_config
.
dtype
}
, "
f
"dtype=
{
self
.
model_config
.
dtype
}
, "
...
...
vllm/config/cache.py
View file @
4172235a
...
@@ -33,9 +33,8 @@ class CacheConfig:
...
@@ -33,9 +33,8 @@ class CacheConfig:
"""Configuration for the KV cache."""
"""Configuration for the KV cache."""
block_size
:
SkipValidation
[
BlockSize
]
=
None
# type: ignore
block_size
:
SkipValidation
[
BlockSize
]
=
None
# type: ignore
"""Size of a contiguous cache block in number of tokens. This is ignored on
"""Size of a contiguous cache block in number of tokens. On CUDA devices,
neuron devices and set to `--max-model-len`. On CUDA devices, only block
only block sizes up to 32 are supported.
sizes up to 32 are supported. On HPU devices, block size defaults to 128.
This config has no static default. If left unspecified by the user, it will
This config has no static default. If left unspecified by the user, it will
be set in `Platform.check_and_update_config()` based on the current
be set in `Platform.check_and_update_config()` based on the current
...
...
vllm/config/parallel.py
View file @
4172235a
...
@@ -377,10 +377,7 @@ class ParallelConfig:
...
@@ -377,10 +377,7 @@ class ParallelConfig:
from
vllm.executor
import
ray_utils
from
vllm.executor
import
ray_utils
backend
:
DistributedExecutorBackend
=
"mp"
backend
:
DistributedExecutorBackend
=
"mp"
ray_found
=
ray_utils
.
ray_is_available
()
ray_found
=
ray_utils
.
ray_is_available
()
if
current_platform
.
is_neuron
():
if
current_platform
.
is_tpu
()
and
envs
.
VLLM_XLA_USE_SPMD
:
# neuron uses single process to control multiple devices
backend
=
"uni"
elif
current_platform
.
is_tpu
()
and
envs
.
VLLM_XLA_USE_SPMD
:
backend
=
"uni"
backend
=
"uni"
elif
(
current_platform
.
is_cuda
()
elif
(
current_platform
.
is_cuda
()
and
cuda_device_count_stateless
()
<
self
.
world_size
):
and
cuda_device_count_stateless
()
<
self
.
world_size
):
...
...
vllm/distributed/device_communicators/neuron_communicator.py
deleted
100644 → 0
View file @
848562bd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
vllm.distributed.device_communicators.base_device_communicator
import
(
DeviceCommunicatorBase
)
from
vllm.platforms
import
current_platform
if
current_platform
.
is_neuron
():
import
torch_xla.core.xla_model
as
xm
class NeuronCommunicator(DeviceCommunicatorBase):
    """Device communicator that forwards collectives to `xm`.

    `xm` is `torch_xla.core.xla_model`, which this module imports only when
    `current_platform.is_neuron()` is true, so these methods are usable only
    on Neuron platforms.
    """

    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
        """Return `xm.all_reduce` of `x` using the `xm.REDUCE_SUM` reduction."""
        return xm.all_reduce(xm.REDUCE_SUM, x)

    def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
        """Return `xm.all_gather` of `x` along `dim`.

        Raises:
            AssertionError: if `dim` is anything other than -1.
        """
        assert dim == -1, "Neuron only supports dim=-1 for all-gather."
        return xm.all_gather(x, dim=dim)
vllm/engine/arg_utils.py
View file @
4172235a
...
@@ -419,8 +419,6 @@ class EngineArgs:
...
@@ -419,8 +419,6 @@ class EngineArgs:
scheduling_policy
:
SchedulerPolicy
=
SchedulerConfig
.
policy
scheduling_policy
:
SchedulerPolicy
=
SchedulerConfig
.
policy
scheduler_cls
:
Union
[
str
,
Type
[
object
]]
=
SchedulerConfig
.
scheduler_cls
scheduler_cls
:
Union
[
str
,
Type
[
object
]]
=
SchedulerConfig
.
scheduler_cls
override_neuron_config
:
dict
[
str
,
Any
]
=
\
get_field
(
ModelConfig
,
"override_neuron_config"
)
override_pooler_config
:
Optional
[
Union
[
dict
,
PoolerConfig
]]
=
\
override_pooler_config
:
Optional
[
Union
[
dict
,
PoolerConfig
]]
=
\
ModelConfig
.
override_pooler_config
ModelConfig
.
override_pooler_config
compilation_config
:
CompilationConfig
=
\
compilation_config
:
CompilationConfig
=
\
...
@@ -561,8 +559,6 @@ class EngineArgs:
...
@@ -561,8 +559,6 @@ class EngineArgs:
help
=
model_kwargs
[
"hf_token"
][
"help"
])
help
=
model_kwargs
[
"hf_token"
][
"help"
])
model_group
.
add_argument
(
"--hf-overrides"
,
model_group
.
add_argument
(
"--hf-overrides"
,
**
model_kwargs
[
"hf_overrides"
])
**
model_kwargs
[
"hf_overrides"
])
model_group
.
add_argument
(
"--override-neuron-config"
,
**
model_kwargs
[
"override_neuron_config"
])
model_group
.
add_argument
(
"--override-pooler-config"
,
model_group
.
add_argument
(
"--override-pooler-config"
,
**
model_kwargs
[
"override_pooler_config"
])
**
model_kwargs
[
"override_pooler_config"
])
model_group
.
add_argument
(
"--logits-processor-pattern"
,
model_group
.
add_argument
(
"--logits-processor-pattern"
,
...
@@ -992,7 +988,6 @@ class EngineArgs:
...
@@ -992,7 +988,6 @@ class EngineArgs:
mm_processor_kwargs
=
self
.
mm_processor_kwargs
,
mm_processor_kwargs
=
self
.
mm_processor_kwargs
,
mm_processor_cache_gb
=
self
.
mm_processor_cache_gb
,
mm_processor_cache_gb
=
self
.
mm_processor_cache_gb
,
mm_encoder_tp_mode
=
self
.
mm_encoder_tp_mode
,
mm_encoder_tp_mode
=
self
.
mm_encoder_tp_mode
,
override_neuron_config
=
self
.
override_neuron_config
,
override_pooler_config
=
self
.
override_pooler_config
,
override_pooler_config
=
self
.
override_pooler_config
,
logits_processor_pattern
=
self
.
logits_processor_pattern
,
logits_processor_pattern
=
self
.
logits_processor_pattern
,
generation_config
=
self
.
generation_config
,
generation_config
=
self
.
generation_config
,
...
...
vllm/envs.py
View file @
4172235a
...
@@ -236,7 +236,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -236,7 +236,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# ================== Installation Time Env Vars ==================
# ================== Installation Time Env Vars ==================
# Target device of vLLM, supporting [cuda (by default),
# Target device of vLLM, supporting [cuda (by default),
# rocm,
neuron,
cpu]
# rocm, cpu]
"VLLM_TARGET_DEVICE"
:
"VLLM_TARGET_DEVICE"
:
lambda
:
os
.
getenv
(
"VLLM_TARGET_DEVICE"
,
"cuda"
).
lower
(),
lambda
:
os
.
getenv
(
"VLLM_TARGET_DEVICE"
,
"cuda"
).
lower
(),
...
...
vllm/model_executor/custom_op.py
View file @
4172235a
...
@@ -73,11 +73,6 @@ class CustomOp(nn.Module):
...
@@ -73,11 +73,6 @@ class CustomOp(nn.Module):
# NOTE(woosuk): This is a placeholder for future extensions.
# NOTE(woosuk): This is a placeholder for future extensions.
return
self
.
forward_native
(
*
args
,
**
kwargs
)
return
self
.
forward_native
(
*
args
,
**
kwargs
)
def
forward_neuron
(
self
,
*
args
,
**
kwargs
):
# By default, we assume that Neuron ops are compatible with the
# PyTorch-native implementation.
return
self
.
forward_native
(
*
args
,
**
kwargs
)
def
forward_oot
(
self
,
*
args
,
**
kwargs
):
def
forward_oot
(
self
,
*
args
,
**
kwargs
):
# By default, we assume that OOT ops are compatible with the
# By default, we assume that OOT ops are compatible with the
# PyTorch-native implementation.
# PyTorch-native implementation.
...
@@ -105,8 +100,6 @@ class CustomOp(nn.Module):
...
@@ -105,8 +100,6 @@ class CustomOp(nn.Module):
return
self
.
forward_tpu
return
self
.
forward_tpu
elif
current_platform
.
is_xpu
():
elif
current_platform
.
is_xpu
():
return
self
.
forward_xpu
return
self
.
forward_xpu
elif
current_platform
.
is_neuron
():
return
self
.
forward_neuron
elif
current_platform
.
is_out_of_tree
():
elif
current_platform
.
is_out_of_tree
():
return
self
.
forward_oot
return
self
.
forward_oot
else
:
else
:
...
...
vllm/model_executor/layers/activation.py
View file @
4172235a
...
@@ -95,13 +95,6 @@ class SiluAndMul(CustomOp):
...
@@ -95,13 +95,6 @@ class SiluAndMul(CustomOp):
self
.
op
(
out
,
x
)
self
.
op
(
out
,
x
)
return
out
return
out
def
forward_neuron
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
d
=
x
.
shape
[
-
1
]
//
2
x_reshaped
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
s
=
x_reshaped
[:,
:
d
]
*
F
.
sigmoid
(
x_reshaped
[:,
:
d
])
result
=
s
*
x_reshaped
[:,
d
:]
return
result
.
view
(
*
x
.
shape
[:
-
1
],
d
)
@
CustomOp
.
register
(
"mul_and_silu"
)
@
CustomOp
.
register
(
"mul_and_silu"
)
class
MulAndSilu
(
CustomOp
):
class
MulAndSilu
(
CustomOp
):
...
...
vllm/model_executor/layers/quantization/__init__.py
View file @
4172235a
...
@@ -26,7 +26,6 @@ QuantizationMethods = Literal[
...
@@ -26,7 +26,6 @@ QuantizationMethods = Literal[
"bitsandbytes"
,
"bitsandbytes"
,
"hqq"
,
"hqq"
,
"experts_int8"
,
"experts_int8"
,
"neuron_quant"
,
"ipex"
,
"ipex"
,
"quark"
,
"quark"
,
"moe_wna16"
,
"moe_wna16"
,
...
@@ -108,7 +107,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
...
@@ -108,7 +107,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from
.modelopt
import
ModelOptFp8Config
,
ModelOptNvFp4Config
from
.modelopt
import
ModelOptFp8Config
,
ModelOptNvFp4Config
from
.moe_wna16
import
MoeWNA16Config
from
.moe_wna16
import
MoeWNA16Config
from
.mxfp4
import
Mxfp4Config
from
.mxfp4
import
Mxfp4Config
from
.neuron_quant
import
NeuronQuantConfig
from
.petit
import
PetitNvFp4Config
from
.petit
import
PetitNvFp4Config
from
.ptpc_fp8
import
PTPCFp8Config
from
.ptpc_fp8
import
PTPCFp8Config
from
.rtn
import
RTNConfig
from
.rtn
import
RTNConfig
...
@@ -135,7 +133,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
...
@@ -135,7 +133,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"ptpc_fp8"
:
PTPCFp8Config
,
"ptpc_fp8"
:
PTPCFp8Config
,
"hqq"
:
HQQMarlinConfig
,
"hqq"
:
HQQMarlinConfig
,
"experts_int8"
:
ExpertsInt8Config
,
"experts_int8"
:
ExpertsInt8Config
,
"neuron_quant"
:
NeuronQuantConfig
,
"ipex"
:
IPEXConfig
,
"ipex"
:
IPEXConfig
,
"quark"
:
QuarkConfig
,
"quark"
:
QuarkConfig
,
"moe_wna16"
:
MoeWNA16Config
,
"moe_wna16"
:
MoeWNA16Config
,
...
...
vllm/model_executor/layers/quantization/neuron_quant.py
deleted
100644 → 0
View file @
848562bd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
os
from
importlib.util
import
find_spec
from
typing
import
Any
,
Optional
from
torch.nn
import
Module
from
vllm.model_executor.layers.quantization
import
QuantizationMethods
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
# Quantized dtype names accepted from the NEURON_QUANT_DTYPE environment
# variable; NeuronQuantConfig.__init__ raises ValueError for any other value.
SUPPORTED_QUANT_DTYPE_LIST = ['s8', 'f8e4m3fn']
class AlwaysSupportedDtypes(list):
    """List subclass whose membership test succeeds for every item."""

    def __contains__(self, item):
        # Report membership unconditionally, regardless of stored elements.
        return True
class NeuronQuantConfig(QuantizationConfig):
    """Quantization config for the Neuron backend.

    The quantized dtype is read from the `NEURON_QUANT_DTYPE` environment
    variable (default "s8") and must be one of SUPPORTED_QUANT_DTYPE_LIST.
    The config is ultimately translated into a
    `transformers_neuronx.config.QuantizationConfig`.
    """

    def __init__(
        self,
        dequant_dtype: str = "f16",
        quantize_method: str = "vector_dynamic",
    ) -> None:
        """Store the settings and validate the environment-selected dtype.

        Raises:
            ValueError: if `NEURON_QUANT_DTYPE` names an unsupported dtype.
        """
        super().__init__()
        self.quant_dtype = os.getenv("NEURON_QUANT_DTYPE", "s8")
        dtype_is_supported = self.quant_dtype in SUPPORTED_QUANT_DTYPE_LIST
        if not dtype_is_supported:
            message = (
                f"Neuron quantization datatype {self.quant_dtype} is not valid,"
                f" the quantization datatype should match one of the below "
                f"types {SUPPORTED_QUANT_DTYPE_LIST}")
            raise ValueError(message)
        self.dequant_dtype = dequant_dtype
        self.quantize_method = quantize_method

    def get_name(self) -> QuantizationMethods:
        """Return the registry name of this quantization method."""
        return "neuron_quant"

    def get_supported_act_dtypes(self) -> list[str]:
        """Return a list-like object that reports every dtype as supported."""
        # Neuron implements custom handling logic for quantization support
        return AlwaysSupportedDtypes()

    @classmethod
    def get_min_capability(cls) -> int:
        """Always raise: this query is not meaningful for the Neuron backend."""
        raise NotImplementedError(
            "This function should not be called with Neuron Backend")

    @staticmethod
    def get_config_filenames() -> list[str]:
        """Return an empty list; no config files are looked up."""
        return []

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "NeuronQuantConfig":
        """Build an instance from `config`.

        The "quantize_method" and "dequant_dtype" entries are read through
        `cls.get_from_keys`, in that order.
        """
        method_name = cls.get_from_keys(config, ["quantize_method"])
        target_dtype = cls.get_from_keys(config, ["dequant_dtype"])
        return cls(dequant_dtype=target_dtype, quantize_method=method_name)

    def get_quant_method(self, layer: Module, prefix: str) -> Optional[Any]:
        """Return the transformers_neuronx quantization config.

        `layer` and `prefix` are accepted but not used here.

        Raises:
            NotImplementedError: if `transformers_neuronx` is not importable.
        """
        if find_spec("transformers_neuronx") is None:
            raise NotImplementedError(
                "Neuron Quantization is only supported through"
                " transformers_neuronx.")
        return self.get_quantization_config()

    def get_quantization_config(self):
        """Create the `transformers_neuronx` QuantizationConfig for this config."""
        # Imported lazily so the module loads without transformers_neuronx.
        # Inside this method the name refers to the transformers_neuronx
        # class, not to this class's vLLM base class.
        from transformers_neuronx.config import QuantizationConfig

        neuronx_kwargs = {
            "quant_dtype": self.quant_dtype,
            "dequant_dtype": self.dequant_dtype,
            "quantize_method": self.quantize_method,
        }
        return QuantizationConfig(**neuronx_kwargs)
vllm/model_executor/layers/rotary_embedding/base.py
View file @
4172235a
...
@@ -7,7 +7,7 @@ import torch
...
@@ -7,7 +7,7 @@ import torch
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.custom_op
import
CustomOp
from
.common
import
apply_rotary_emb_dispatch
,
apply_rotary_emb_torch
from
.common
import
apply_rotary_emb_torch
@
CustomOp
.
register
(
"rotary_embedding"
)
@
CustomOp
.
register
(
"rotary_embedding"
)
...
@@ -149,87 +149,6 @@ class RotaryEmbedding(CustomOp):
...
@@ -149,87 +149,6 @@ class RotaryEmbedding(CustomOp):
self
.
cos_sin_cache
,
self
.
is_neox_style
)
self
.
cos_sin_cache
,
self
.
is_neox_style
)
return
query
,
key
return
query
,
key
def forward_neuron(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
    offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Apply rotary embedding to ``query`` (and ``key`` when given).

    ``offsets``, when provided, is added to ``positions`` before the cos/sin
    rows are looked up in ``self.cos_sin_cache``. The cache is moved to the
    device and dtype of ``query`` and stored back on ``self``.

    Returns:
        ``(query, key)`` reshaped back to their incoming shapes; ``key`` is
        returned unchanged as ``None`` when it was not supplied.
    """

    def _apply_rotary_emb_neuron(
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        is_neox_style: bool,
    ) -> torch.Tensor:
        # Insert a broadcast axis in front of the last dimension.
        cos = cos.unsqueeze(-2).to(x.dtype)
        sin = sin.unsqueeze(-2).to(x.dtype)
        if is_neox_style:
            first, second = torch.chunk(x, 2, dim=-1)
        else:
            # first = x[..., ::2]
            # second = x[..., 1::2]
            half = x.shape[-1] // 2
            flat = x.view(-1, x.shape[-1])
            first = flat[:, ::2].view(*x.shape[:-1], half)
            second = flat[:, 1::2].view(*x.shape[:-1], half)
        rot_first = first * cos - second * sin
        rot_second = second * cos + first * sin
        if is_neox_style:
            return torch.cat((rot_first, rot_second), dim=-1)
        return torch.stack((rot_first, rot_second), dim=-1).flatten(-2)

    if offsets is not None:
        positions = positions + offsets

    self.cos_sin_cache = self.cos_sin_cache.to(query.device,
                                               dtype=query.dtype)

    positions = positions.flatten()
    num_tokens = positions.shape[0]
    cos, sin = self.cos_sin_cache.index_select(0, positions).chunk(2, dim=-1)

    # Remember the caller's shapes so the outputs can be restored to them.
    query_shape = query.shape
    query = query.view(num_tokens, -1, self.head_size)
    if key is not None:
        key_shape = key.shape
        key = key.view(num_tokens, -1, self.head_size)

    if self.rotary_dim == self.head_size:
        # Whole head is rotated.
        query = apply_rotary_emb_dispatch(query, cos, sin,
                                          self.is_neox_style)
        query = query.reshape(query_shape)
        if key is not None:
            key = apply_rotary_emb_dispatch(key, cos, sin,
                                            self.is_neox_style)
            key = key.reshape(key_shape)
        return query, key

    # Partial rotation: only the first ``rotary_dim`` features are rotated,
    # the remainder is passed through and concatenated back.
    head_size = query.shape[-1]

    def _rotate_leading_features(tensor, restored_shape):
        flat = tensor.view(-1, head_size)
        untouched = flat[:, self.rotary_dim:].view(
            *tensor.shape[:-1], head_size - self.rotary_dim)
        rotated = flat[:, :self.rotary_dim].view(*tensor.shape[:-1],
                                                 self.rotary_dim)
        rotated = _apply_rotary_emb_neuron(rotated, cos, sin,
                                           self.is_neox_style)
        return torch.cat((rotated, untouched),
                         dim=-1).reshape(restored_shape)

    query = _rotate_leading_features(query, query_shape)
    if key is not None:
        key = _rotate_leading_features(key, key_shape)
    return query, key
def
extra_repr
(
self
)
->
str
:
def
extra_repr
(
self
)
->
str
:
s
=
f
"head_size=
{
self
.
head_size
}
, rotary_dim=
{
self
.
rotary_dim
}
"
s
=
f
"head_size=
{
self
.
head_size
}
, rotary_dim=
{
self
.
rotary_dim
}
"
s
+=
f
", max_position_embeddings=
{
self
.
max_position_embeddings
}
"
s
+=
f
", max_position_embeddings=
{
self
.
max_position_embeddings
}
"
...
...
vllm/model_executor/model_loader/neuron.py
deleted
100644 → 0
View file @
848562bd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utilities for selecting and loading Neuron models in transformers-neuronx
framework."""
import
ast
import
copy
import
importlib
import
os
from
typing
import
Optional
import
torch
import
torch.nn
as
nn
from
transformers
import
PretrainedConfig
from
vllm.config
import
(
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
SpeculativeConfig
)
from
vllm.logprobs
import
Logprob
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
get_quantization_config
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
CompletionSequenceGroupOutput
,
SequenceOutput
# Maps dtype names and torch dtypes to the short precision strings used as
# the ``amp`` argument when loading weights in this module.
TORCH_DTYPE_TO_NEURON_AMP = {
    # String spellings.
    "auto": "f32",
    "half": "f16",
    "float16": "f16",
    "bfloat16": "bf16",
    "float": "f32",
    "float32": "f32",
    # torch dtype objects.
    torch.float16: "f16",
    torch.bfloat16: "bf16",
    torch.float32: "f32",
}
# Models supported by Neuron.
_NEURON_SUPPORTED_MODELS
:
dict
[
str
,
tuple
[
str
,
str
,
str
]]
=
{
"LlamaForCausalLM"
:
(
"transformers_neuronx.llama.model"
,
"LlamaForSampling"
,
"LlamaForCausalLM"
),
"MistralForCausalLM"
:
(
"transformers_neuronx.mistral.model"
,
"MistralForSampling"
,
"MistralForCausalLM"
)
}
class NeuronCausalLM(nn.Module):
    """Thin ``nn.Module`` wrapper around a transformers-neuronx model.

    The wrapped model is created by :meth:`load_weights`; until then
    ``self.model`` is only declared, not assigned.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 on_device_sampling_disabled: bool = False) -> None:
        super().__init__()
        self.config = config
        self.logits_processor = LogitsProcessor(config.vocab_size,
                                                logits_as_input=True)

        self.on_device_sampling_disabled = on_device_sampling_disabled
        # A host-side sampler is only created when on-device sampling is off.
        if on_device_sampling_disabled:
            self.sampler = Sampler()

        # Lazy initialized
        self.model: nn.Module

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_block_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Run the wrapped model and return whatever it returns."""
        return self.model(input_ids,
                          cache_ids=positions,
                          start_ids=input_block_ids)

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Pass ``hidden_states`` through the logits processor."""
        return self.logits_processor(None, hidden_states, sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Produce sampler output for one step.

        With on-device sampling disabled the default sampler is used.
        Otherwise ``logits`` is treated as already-sampled token ids, which
        are consumed in order across all sequences of all sequence groups.
        """
        if self.on_device_sampling_disabled:
            return self.sampler(logits, sampling_metadata)

        # On-device sampling outputs the token ids directly.
        flat_token_ids = logits.flatten()
        group_outputs = []
        cursor = 0
        for group in sampling_metadata.seq_groups:
            group_samples = []
            for seq_id in group.seq_ids:
                token = flat_token_ids[cursor].item()
                cursor += 1
                group_samples.append(
                    SequenceOutput(parent_seq_id=seq_id,
                                   output_token=token,
                                   logprobs={token: Logprob(token)}))
            group_outputs.append(
                CompletionSequenceGroupOutput(samples=group_samples,
                                              prompt_logprobs=None))

        return SamplerOutput(outputs=group_outputs)

    def load_weights(self, model_name_or_path: str, **kwargs):
        """Instantiate the Neuron model class for this architecture.

        The class is looked up in ``_NEURON_SUPPORTED_MODELS``, created via
        ``from_pretrained(model_name_or_path, **kwargs)`` and then
        ``to_neuron()`` is called on it.
        """
        arch = _get_model_architecture(self.config)
        # The third tuple element (HF class name) is not needed here.
        module_path, class_name, _ = _NEURON_SUPPORTED_MODELS[arch]
        model_cls = getattr(importlib.import_module(module_path), class_name)

        self.model = model_cls.from_pretrained(model_name_or_path, **kwargs)
        self.model.to_neuron()
class NeuronSpeculationCausalLM(nn.Module):
    """A Neuron-optimized causal language model with speculative decoding."""

    # Sentinel written into token slots past the accepted count.
    SPECULATION_TERMINATION_ID = -1

    def __init__(self, speculation_model) -> None:
        super().__init__()
        self.model = speculation_model

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_block_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Run one speculative iteration and mask out unaccepted tokens."""
        tokens, counts = self.model.speculative_iteration(
            input_ids, positions, input_block_ids)

        # Mark the end of accepted speculative tokens for each sequence with
        # the speculation termination id.
        num_rows, num_cols = tokens.shape
        step_index = torch.arange(num_cols).expand(num_rows, -1)
        beyond_accepted = step_index >= counts
        tokens[beyond_accepted] = self.SPECULATION_TERMINATION_ID

        return tokens

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[list[SamplerOutput]]:
        """Turn a (batch, steps) table of token ids into per-step outputs.

        Iteration stops at the first step whose entries are all equal to
        ``SPECULATION_TERMINATION_ID``.
        """
        batch_size, num_steps = logits.shape
        seq_ids = [
            seq_id for sg in sampling_metadata.seq_groups
            for seq_id in sg.seq_ids
        ]
        # Organize input tensors by step instead of by sequence.
        tokens_by_step = logits.transpose(0, 1).tolist()

        per_step_outputs = []
        for step_tokens in tokens_by_step:
            step_is_empty = all(
                token_id == self.SPECULATION_TERMINATION_ID
                for token_id in step_tokens)
            if step_is_empty:
                break
            # One single-sample group per sequence in the batch.
            groups = [
                CompletionSequenceGroupOutput(
                    samples=[
                        SequenceOutput(
                            parent_seq_id=seq_ids[seq_index],
                            output_token=token_id,
                            logprobs={token_id: Logprob(token_id)})
                    ],
                    prompt_logprobs=None)
                for seq_index, token_id in enumerate(step_tokens)
            ]
            per_step_outputs.append(SamplerOutput(outputs=groups))
        return per_step_outputs
def _get_model_architecture(config: PretrainedConfig) -> str:
    """Return the first architecture in ``config`` that Neuron supports.

    Args:
        config: model config whose ``architectures`` attribute lists
            candidate architecture names.

    Returns:
        The first name that is a key of ``_NEURON_SUPPORTED_MODELS``.

    Raises:
        ValueError: if no listed architecture is supported, including when
            ``architectures`` is missing, ``None`` or empty.
    """
    # ``architectures`` may be present but set to None; the getattr default
    # alone would not cover that and the loop would raise TypeError.
    architectures = getattr(config, "architectures", None) or []
    for arch in architectures:
        if arch in _NEURON_SUPPORTED_MODELS:
            return arch
    raise ValueError(
        f"Model architectures {architectures} are not supported on Neuron "
        f"for now. Supported architectures: "
        f"{list(_NEURON_SUPPORTED_MODELS.keys())}")
def
_get_buckets
(
env
:
str
,
default_value
:
list
[
int
])
->
list
[
int
]:
env_value
=
os
.
getenv
(
env
)
if
env_value
is
None
:
return
default_value
buckets_remove_empty
=
filter
(
lambda
x
:
x
is
not
None
and
len
(
x
.
strip
())
>
0
,
env_value
.
split
(
","
))
buckets_int
=
map
(
int
,
buckets_remove_empty
)
buckets_list
=
list
(
buckets_int
)
return
buckets_list
def _get_default_neuron_config(model_config: ModelConfig,
                               parallel_config: ParallelConfig,
                               scheduler_config: SchedulerConfig):
    """Generate a neuron config based on vllm config args.

    Returns a plain dict of keyword arguments. ``parallel_config`` is
    accepted but not read here.
    """
    from transformers_neuronx.config import ContinuousBatchingConfig
    from transformers_neuronx.constants import LAYOUT_BSH

    batching = ContinuousBatchingConfig(
        batch_size_for_shared_caches=scheduler_config.max_num_seqs)

    quant_settings = dict(
        dequant_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
        quantize_method="vector_dynamic")

    # Only resolve a quantization method when quantization is requested.
    quant = None
    if model_config.quantization:
        quant_cls = get_quantization_config(model_config.quantization)
        quant = quant_cls.from_config(quant_settings).get_quant_method(
            None, "")

    # TODO: Add Paged attention config to the default neuron arguments.
    return dict(
        collectives_layout=LAYOUT_BSH,
        attention_layout=LAYOUT_BSH,
        fuse_qkv=True,
        quant=quant,
        continuous_batching=batching,
        weight_tiling=bool(model_config.quantization),
        on_device_generation=_get_neuron_on_device_generation_config(
            model_config))
def _get_default_neuron_config_for_speculation(
        model_config: ModelConfig, parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig):
    """Generate a neuron config for speculative decoding based on
    vllm config args.

    Returns a plain dict of keyword arguments. ``parallel_config`` is
    accepted but not read here.
    """
    from transformers_neuronx.config import ContinuousBatchingConfig
    from transformers_neuronx.constants import LAYOUT_BSH

    batching = ContinuousBatchingConfig(
        batch_size_for_shared_caches=scheduler_config.max_num_seqs)

    neuron_args = {
        "collectives_layout": LAYOUT_BSH,
        "attention_layout": LAYOUT_BSH,
        "fuse_qkv": True,
        "on_device_embedding": True,
        "continuous_batching": batching,
        # Deep copy so later changes do not touch the model config's params.
        "on_device_generation":
        copy.deepcopy(model_config.neuron_sampling_params),
    }
    return neuron_args
def _get_neuron_on_device_generation_config(model_config: ModelConfig):
    """Return a deep copy of the sampling params, or ``None`` if disabled."""
    if _is_neuron_on_device_sampling_disabled(model_config):
        return None
    return copy.deepcopy(model_config.neuron_sampling_params)
def _is_neuron_on_device_sampling_disabled(model_config: ModelConfig) -> bool:
    """Return True when ``neuron_sampling_params`` is missing or falsy."""
    sampling_params = getattr(model_config, "neuron_sampling_params", None)
    return not sampling_params
def _get_neuron_config_after_override(default_neuron_config,
                                      overridden_neuron_config):
    """Merge user overrides into the defaults and build a ``NeuronConfig``.

    Selected nested override entries are converted from keyword dicts into
    their transformers_neuronx config classes first. Both input dicts are
    modified in place: nested keys are popped from (and possibly re-added
    to) ``overridden_neuron_config``, and ``default_neuron_config`` is
    updated with the overrides.
    """
    from transformers_neuronx.config import (ContinuousBatchingConfig,
                                             GenerationConfig,
                                             KVCacheQuantizationConfig,
                                             NeuronConfig, QuantizationConfig,
                                             SparseAttnConfig)

    # Override key -> config class used to wrap its keyword dict.
    nested_config_types = (
        ("sparse_attn", SparseAttnConfig),
        ("kv_cache_quant", KVCacheQuantizationConfig),
        ("continuous_batching", ContinuousBatchingConfig),
        ("quant", QuantizationConfig),
        ("on_device_generation", GenerationConfig),
    )
    for field_name, config_cls in nested_config_types:
        raw_value = overridden_neuron_config.pop(field_name, {})
        # A missing or falsy entry is dropped instead of being converted.
        if raw_value:
            overridden_neuron_config[field_name] = config_cls(**raw_value)

    default_neuron_config.update(overridden_neuron_config)
    return NeuronConfig(**default_neuron_config)
def get_neuron_model(model_config: ModelConfig,
                     parallel_config: ParallelConfig,
                     scheduler_config: SchedulerConfig) -> nn.Module:
    """Initializes a neuron-optimized model for inference."""
    # Create a model instance.
    model = NeuronCausalLM(
        model_config.hf_config,
        _is_neuron_on_device_sampling_disabled(model_config))

    neuron_config = _get_neuron_config_after_override(
        _get_default_neuron_config(model_config, parallel_config,
                                   scheduler_config),
        model_config.override_neuron_config)

    # Bucket sizes come from the environment, defaulting to max_model_len.
    context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS",
                                            [scheduler_config.max_model_len])
    n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS",
                               [scheduler_config.max_model_len])

    load_kwargs = dict(
        tp_degree=parallel_config.tensor_parallel_size,
        amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
        neuron_config=neuron_config,
        context_length_estimate=context_length_estimates,
        n_positions=n_positions,
        batch_size=scheduler_config.max_num_seqs,
    )
    model.load_weights(model_config.model, **load_kwargs)

    return model.eval()
def get_neuron_speculation_model(model_config: ModelConfig,
                                 parallel_config: ParallelConfig,
                                 scheduler_config: SchedulerConfig,
                                 speculation_config: SpeculativeConfig):
    """Initializes a neuron-optimized speculation model for inference.

    This method is only applicable for speculation with a standalone draft
    model. The target and draft models are each loaded, then fused in a
    ``FusedSpeculativeDecoder``.
    """
    from transformers_neuronx.fused_speculation import FusedSpeculativeDecoder

    draft_model_config = speculation_config.draft_model_config

    # For Eagle SD, we need to pass in additional parameters in neuron
    # config.
    is_eagle = getattr(draft_model_config.hf_config, "is_eagle", False)

    # Create target model instance.
    target_model = NeuronCausalLM(model_config.hf_config)

    target_args = _get_default_neuron_config_for_speculation(
        model_config, parallel_config, scheduler_config)
    if is_eagle:
        target_args['is_eagle_target'] = True

    target_neuron_config = _get_neuron_config_after_override(
        target_args, model_config.override_neuron_config)

    # Keyword arguments shared by the target and draft load calls.
    shared_load_kwargs = dict(
        context_length_estimate=_get_buckets(
            "NEURON_CONTEXT_LENGTH_BUCKETS",
            [scheduler_config.max_model_len]),
        n_positions=_get_buckets("NEURON_TOKEN_GEN_BUCKETS",
                                 [scheduler_config.max_model_len]),
        batch_size=scheduler_config.max_num_seqs,
    )

    target_model.load_weights(
        model_config.model,
        tp_degree=parallel_config.tensor_parallel_size,
        amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
        neuron_config=target_neuron_config,
        **shared_load_kwargs)
    target_model.eval()

    # Create draft model instance.
    draft_model = NeuronCausalLM(draft_model_config.hf_config)

    draft_args = _get_default_neuron_config_for_speculation(
        draft_model_config, parallel_config, scheduler_config)
    if is_eagle:
        draft_args['is_eagle_draft'] = True
        draft_args['has_pre_attention_norm'] = False

    draft_neuron_config = _get_neuron_config_after_override(
        draft_args, draft_model_config.override_neuron_config)

    draft_model.load_weights(
        draft_model_config.model,
        tp_degree=speculation_config.draft_parallel_config.
        tensor_parallel_size,
        amp=TORCH_DTYPE_TO_NEURON_AMP[draft_model_config.dtype],
        neuron_config=draft_neuron_config,
        **shared_load_kwargs)
    draft_model.eval()

    # Create speculation model instance.
    speculation_model = FusedSpeculativeDecoder(
        draft_model.model, target_model.model,
        speculation_config.num_speculative_tokens)
    speculation_model.to_neuron()

    return NeuronSpeculationCausalLM(speculation_model)
def get_neuron_eagle_speculation_model(model_config: ModelConfig,
                                       parallel_config: ParallelConfig,
                                       scheduler_config: SchedulerConfig,
                                       speculation_config: SpeculativeConfig):
    """Initializes a neuron-optimized EAGLE speculation model for inference.

    The target and draft models are each loaded with their EAGLE flags set,
    then combined in an ``EagleSpeculativeDecoder`` using the token tree
    parsed from ``speculation_config.speculative_token_tree``.
    """
    from transformers_neuronx.eagle_speculation import EagleSpeculativeDecoder

    draft_model_config = speculation_config.draft_model_config

    # Create target model instance.
    target_model = NeuronCausalLM(model_config.hf_config)

    target_args = _get_default_neuron_config_for_speculation(
        model_config, parallel_config, scheduler_config)
    target_args['is_eagle_target'] = True

    target_neuron_config = _get_neuron_config_after_override(
        target_args, model_config.override_neuron_config)

    # Keyword arguments shared by the target and draft load calls.
    shared_load_kwargs = dict(
        context_length_estimate=_get_buckets(
            "NEURON_CONTEXT_LENGTH_BUCKETS",
            [scheduler_config.max_model_len]),
        n_positions=_get_buckets("NEURON_TOKEN_GEN_BUCKETS",
                                 [scheduler_config.max_model_len]),
        batch_size=scheduler_config.max_num_seqs,
    )

    target_model.load_weights(
        model_config.model,
        tp_degree=parallel_config.tensor_parallel_size,
        amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
        neuron_config=target_neuron_config,
        **shared_load_kwargs)
    target_model.eval()

    # Create draft model instance.
    draft_model = NeuronCausalLM(draft_model_config.hf_config)

    draft_args = _get_default_neuron_config_for_speculation(
        draft_model_config, parallel_config, scheduler_config)
    draft_args['is_eagle_draft'] = True
    draft_args['has_pre_attention_norm'] = False

    draft_neuron_config = _get_neuron_config_after_override(
        draft_args, draft_model_config.override_neuron_config)

    draft_model.load_weights(
        draft_model_config.model,
        tp_degree=speculation_config.draft_parallel_config.
        tensor_parallel_size,
        amp=TORCH_DTYPE_TO_NEURON_AMP[draft_model_config.dtype],
        neuron_config=draft_neuron_config,
        **shared_load_kwargs)
    draft_model.eval()

    # The token tree is stored as a Python-literal string on the config.
    token_tree: dict[int, list[int]] = ast.literal_eval(
        speculation_config.speculative_token_tree)

    speculation_model = EagleSpeculativeDecoder(draft_model.model,
                                                target_model.model,
                                                token_tree=token_tree)
    speculation_model.to_neuron()

    return NeuronSpeculationCausalLM(speculation_model)
vllm/model_executor/model_loader/neuronx_distributed.py
deleted
100644 → 0
View file @
848562bd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utilities for selecting and loading Neuron models in
neuronx-distributed-inference framework."""
# Disabling yapf because yapf and isort have conflicts for the below imports
# yapf: disable
import
copy
import
hashlib
import
importlib
import
multiprocessing
import
os
import
shutil
from
typing
import
Optional
import
torch
import
torch.nn
as
nn
from
neuronx_distributed_inference.models.config
import
(
FusedSpecNeuronConfig
,
OnDeviceSamplingConfig
)
from
neuronx_distributed_inference.models.mllama.utils
import
(
create_vision_mask
)
from
neuronx_distributed_inference.modules.lora_serving
import
(
LoraServingConfig
)
from
neuronx_distributed_inference.utils.hf_adapter
import
(
load_pretrained_config
)
from
transformers
import
AutoModelForCausalLM
,
AutoTokenizer
,
PretrainedConfig
from
vllm.config
import
(
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
SpeculativeConfig
)
from
vllm.logger
import
init_logger
from
vllm.logprobs
import
Logprob
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
CompletionSequenceGroupOutput
,
SequenceOutput
# yapf: enable
logger
=
init_logger
(
__name__
)
# Maps dtype names and torch dtypes to the full-length dtype strings used by
# this module.
TORCH_DTYPE_TO_NEURON_AMP = {
    # String spellings.
    "auto": "float32",
    "half": "float16",
    "float16": "float16",
    "bfloat16": "bfloat16",
    "float": "float32",
    "float32": "float32",
    # torch dtype objects.
    torch.float16: "float16",
    torch.bfloat16: "bfloat16",
    torch.float32: "float32",
}
# Models supported by Neuronx distributed for inference.
_NEURON_SUPPORTED_MODELS
:
dict
[
str
,
tuple
[
str
,
str
]]
=
{
"LlamaForCausalLM"
:
(
"neuronx_distributed_inference.models.llama.modeling_llama"
,
"NeuronLlamaForCausalLM"
),
"MistralForCausalLM"
:
(
"neuronx_distributed_inference.models.llama.modeling_llama"
,
"NeuronLlamaForCausalLM"
),
"DbrxForCausalLM"
:
(
"neuronx_distributed_inference.models.dbrx.modeling_dbrx"
,
"NeuronDbrxForCausalLM"
),
"MixtralForCausalLM"
:
(
"neuronx_distributed_inference.models.mixtral.modeling_mixtral"
,
"NeuronMixtralForCausalLM"
),
"MllamaForConditionalGeneration"
:
(
"neuronx_distributed_inference.models.mllama.modeling_mllama"
,
"NeuronMllamaForCausalLM"
),
}
class NeuronCausalLM(nn.Module):
    """``nn.Module`` wrapper around a neuronx-distributed-inference model.

    The wrapped model is created by :meth:`load_weights`; until then
    ``self.model`` is only declared, not assigned.
    """

    def __init__(
        self,
        config: PretrainedConfig,
    ) -> None:
        super().__init__()
        self.config = config
        self.logits_processor = LogitsProcessor(config.vocab_size,
                                                logits_as_input=True)
        self.sampler = Sampler()

        # Lazy initialized
        self.model: nn.Module

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_block_ids: torch.Tensor,
        sampling_params: torch.Tensor,
        prev_hidden: Optional[torch.Tensor] = None,
        adapter_ids: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Run the model on a batch sorted by block id, then unsort.

        Returns the model's ``hidden_states`` when on-device sampling is
        configured, otherwise the last-position slice of its ``logits``.
        """
        # sort block ids sequentially for perf/neuron support reasons
        sorted_block_ids, order = torch.sort(input_block_ids)

        def _reorder(tensor):
            return torch.index_select(tensor, 0, order)

        input_ids = _reorder(input_ids)
        positions = _reorder(positions)
        sampling_params = _reorder(sampling_params)

        model_out = self.model(input_ids,
                               attention_mask=None,
                               position_ids=positions,
                               seq_ids=sorted_block_ids,
                               sampling_params=sampling_params,
                               prev_hidden=prev_hidden,
                               adapter_ids=adapter_ids)
        # on-device sampling
        if self.config.neuron_config.on_device_sampling_config:
            result = model_out.hidden_states
        else:
            result = model_out.logits[:, -1, :]

        # Undo the sort so rows line up with the caller's ordering again.
        inverse_order = torch.argsort(order)
        if input_block_ids.shape[0] != 1:
            result = torch.index_select(result, 0, inverse_order)

        return result

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Pass ``hidden_states`` through the logits processor."""
        return self.logits_processor(None, hidden_states, sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Produce sampler output for one step.

        With on-device sampling configured, ``logits`` is treated as
        already-sampled token ids (one per sequence); otherwise the default
        sampler is used.
        """
        if not self.config.neuron_config.on_device_sampling_config:
            return self.sampler(logits, sampling_metadata)

        # on-device sampling
        seq_ids = [
            seq_id for sg in sampling_metadata.seq_groups
            for seq_id in sg.seq_ids
        ]
        assert len(seq_ids) == list(logits.shape)[0], "batch size mismatch"
        # Organize input tensors by step instead of by sequence.
        token_ids = logits.flatten().tolist()

        groups = []
        for row, seq_id in enumerate(seq_ids):
            token_id = token_ids[row]
            groups.append(
                CompletionSequenceGroupOutput(samples=[
                    SequenceOutput(parent_seq_id=seq_id,
                                   output_token=token_id,
                                   logprobs={token_id: Logprob(token_id)})
                ],
                                              prompt_logprobs=None))
        return SamplerOutput(outputs=groups)

    def load_weights(self, model_name_or_path: str, **kwargs):
        """Load compiled artifacts, compiling the model first if needed.

        Expects ``kwargs['neuron_config']`` (keyword dict for the neuron
        config class) and ``kwargs['override_neuron_config']`` (attributes
        set on the loaded model's neuron config).
        """
        arch = _get_model_architecture(self.config)
        module_path, class_name = _NEURON_SUPPORTED_MODELS[arch]
        model_cls = getattr(importlib.import_module(module_path), class_name)

        neuron_config = model_cls.get_neuron_config_cls()(
            **kwargs['neuron_config'])
        self.config.neuron_config = neuron_config
        config = model_cls.get_config_cls()(
            neuron_config,
            load_config=load_pretrained_config(model_name_or_path))
        # The config hash names the artifact directory.
        hashed_config = hashlib.md5(config.to_json_string().encode('utf-8'),
                                    usedforsecurity=False).hexdigest()

        artifacts_override = os.getenv("NEURON_COMPILED_ARTIFACTS")
        if artifacts_override is not None:
            compiled_model_path = artifacts_override
        else:
            if os.path.exists(model_name_or_path):
                artifacts_root = model_name_or_path
            else:
                artifacts_root = os.path.join("local-models",
                                              model_name_or_path)
            compiled_model_path = os.path.join(artifacts_root,
                                               "neuron-compiled-artifacts",
                                               hashed_config)
            # Any existing directory at the derived path is removed first.
            shutil.rmtree(compiled_model_path, ignore_errors=True)

        try:
            self.model = model_cls(compiled_model_path)
            for attr_name, attr_value in kwargs[
                    "override_neuron_config"].items():
                setattr(self.model.config.neuron_config, attr_name,
                        attr_value)
            self.model.load(compiled_model_path)
            return
        except (FileNotFoundError, ValueError) as e:
            logger.warning("Exception: %s", e)
            logger.warning("Failed to load the model from %s, Recompiling...",
                           compiled_model_path)

        # Fallback: materialise the checkpoint locally if it is not a path,
        # then compile and load.
        if not os.path.exists(model_name_or_path):
            hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
            saved_path = os.path.join("local-models", model_name_or_path)
            hf_model.save_pretrained(saved_path)
            model_name_or_path = saved_path
        self.model = model_cls(model_name_or_path, config)
        self.model.compile(compiled_model_path)
        self.model.load(compiled_model_path)
class NeuronMllamaForCausalLM(nn.Module):
    """``nn.Module`` wrapper around a neuronx-distributed-inference Mllama
    model.

    The wrapped model and ``vision_token_id`` are set by
    :meth:`load_weights`; :meth:`forward` reads both.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 on_device_sampling_disabled: bool = False) -> None:
        super().__init__()
        # has_image is the only multimodal input that is used in
        # token-generation
        # This is a cache (on CPU) that saves has_image data per sequence id
        # The number of entries in this cache is <= Batch-Size
        self.has_image_cache: dict[int, torch.Tensor] = {}
        self.config = config
        self.logits_processor = LogitsProcessor(
            config.get_text_config().vocab_size, logits_as_input=True)

        self.on_device_sampling_disabled = on_device_sampling_disabled
        if self.on_device_sampling_disabled:
            # Use default sampler
            self.sampler = Sampler()

        # Lazy initialized
        self.model: nn.Module
        self.is_reorder_needed: bool = True

    def read_from_has_image_cache(self, seq_ids: torch.Tensor):
        """Return the cached ``has_image`` entry for every id in ``seq_ids``.

        Ids without a cache entry contribute ``torch.tensor([0])``.
        """
        has_image_list = []
        for seq in seq_ids:
            seq_id = seq.item()
            if seq_id in self.has_image_cache:
                has_image_list.append(self.has_image_cache[seq_id])
            else:
                has_image_list.append(torch.tensor([0]))
        return torch.tensor(has_image_list)

    def write_to_has_image_cache(self, seq_ids: torch.Tensor,
                                 has_image: torch.Tensor):
        """Store ``has_image[i]`` under ``seq_ids[i]``.

        Ids beyond the length of ``has_image`` are stored as
        ``torch.zeros(1)``.
        """
        for index, seq in enumerate(seq_ids):
            seq_id = seq.item()
            if index < len(has_image):
                self.has_image_cache[seq_id] = has_image[index]
            else:
                self.has_image_cache[seq_id] = torch.zeros(1)

    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
                seq_ids: torch.Tensor, pixel_values: torch.Tensor,
                aspect_ratios: torch.Tensor, num_chunks: torch.Tensor,
                has_image: torch.Tensor, sampling_params) -> torch.Tensor:
        """Run the model, sorting the batch by sequence id when enabled.

        Returns the model's ``hidden_states`` when on-device sampling is
        configured, otherwise the last-position slice of its ``logits``;
        rows are restored to the caller's order.
        """
        # We update the has_image cache during prefill
        # and read the has_image cache during decode
        if input_ids.shape[-1] > 1:  # prefill
            self.write_to_has_image_cache(seq_ids, has_image)
        else:
            has_image = self.read_from_has_image_cache(seq_ids)
            bs = input_ids.shape[0]
            num_chunks = torch.zeros((bs, 1))
            aspect_ratios = torch.zeros((bs, 1, 2))

        input_block_ids = seq_ids
        origin_input_block_ids = seq_ids
        if self.is_reorder_needed:
            # sort block ids sequentially for perf/neuron support reasons
            input_block_ids, sorted_indices = torch.sort(input_block_ids)
            input_ids = torch.index_select(input_ids, 0, sorted_indices)
            positions = torch.index_select(positions, 0, sorted_indices)
            sampling_params = torch.index_select(sampling_params, 0,
                                                 sorted_indices)
            pixel_values = torch.index_select(pixel_values, 0, sorted_indices)
            aspect_ratios = torch.index_select(aspect_ratios, 0,
                                               sorted_indices)
            num_chunks = torch.index_select(num_chunks, 0, sorted_indices)
            has_image = torch.index_select(has_image, 0, sorted_indices)

        self.vision_mask = create_vision_mask(input_ids, self.vision_token_id)
        output = self.model(
            input_ids.to(torch.int32),
            attention_mask=None,
            position_ids=positions.to(torch.int32),
            # Pass the ids in the same (possibly sorted) order as every
            # other per-row input; the unsorted ids would pair rows with
            # the wrong sequence once the batch has been reordered.
            seq_ids=input_block_ids.flatten().to(torch.int32),
            pixel_values=pixel_values.to(
                self.config.vision_config.torch_dtype),
            aspect_ratios=aspect_ratios.to(torch.int32),
            vision_mask=self.vision_mask.to(torch.int32),
            sampling_params=sampling_params,
            num_chunks=num_chunks.to(torch.int32),
            has_image=has_image.to(torch.int32),
        )
        if self.config.neuron_config.on_device_sampling_config:
            output = output.hidden_states
        else:
            output = output.logits[:, -1, :]

        if self.is_reorder_needed and origin_input_block_ids.shape[0] != 1:
            # Undo the sort so rows line up with the caller's ordering.
            restored_indices = torch.argsort(sorted_indices)
            output = torch.index_select(output, 0, restored_indices)
        return output

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Pass ``hidden_states`` through the logits processor."""
        logits = self.logits_processor(None, hidden_states, sampling_metadata)
        return logits

    def sample(self, hidden_states, sampling_metadata):
        """Produce sampler output for one step.

        With on-device sampling enabled, ``hidden_states`` is treated as
        already-sampled token ids consumed in order across all sequence
        groups; otherwise the default sampler is used.
        """
        if not self.on_device_sampling_disabled:
            with torch.profiler.record_function("sample"):
                hidden_states = hidden_states.flatten()
                res = []
                sample_idx = 0
                for seq_group in sampling_metadata.seq_groups:
                    samples = []
                    for seq_id in seq_group.seq_ids:
                        token_id = hidden_states[sample_idx].item()
                        samples.append(
                            SequenceOutput(
                                parent_seq_id=seq_id,
                                output_token=token_id,
                                logprobs={token_id: Logprob(token_id)}))
                        sample_idx += 1
                    res.append(
                        CompletionSequenceGroupOutput(samples=samples,
                                                      prompt_logprobs=None))
                next_tokens = SamplerOutput(outputs=res)
        else:
            # Same two-argument call as the other model wrappers in this
            # file; the sampler is not the logits processor and takes no
            # leading ``None`` argument.
            next_tokens = self.sampler(hidden_states, sampling_metadata)
        return next_tokens

    def load_weights(self, model_name_or_path: str, **kwargs):
        """Load compiled artifacts, compiling in a subprocess if needed.

        Expects ``kwargs['neuron_config']`` (keyword dict for the neuron
        config class). Also sets ``self.vision_token_id`` from the
        tokenizer's encoding of ``"<|image|>"``.
        """
        arch = _get_model_architecture(self.config)
        neuronx_module_path, neuronx_model_cls_name = (
            _NEURON_SUPPORTED_MODELS[arch])
        neuronx_module = importlib.import_module(neuronx_module_path)
        neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)
        neuron_config = neuronx_model_cls.get_neuron_config_cls()(
            **kwargs['neuron_config'])
        self.config.neuron_config = neuron_config
        logger.info("neuron_config buckets: %s",
                    self.config.neuron_config.buckets)
        config = neuronx_model_cls.get_config_cls()(
            neuron_config,
            load_config=load_pretrained_config(model_name_or_path))
        # The config hash names the artifact directory.
        hashed_config = hashlib.md5(config.to_json_string().encode('utf-8'),
                                    usedforsecurity=False).hexdigest()
        if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None:
            compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS")
        elif os.path.exists(model_name_or_path):
            compiled_model_path = os.path.join(model_name_or_path,
                                               "neuron-compiled-artifacts",
                                               hashed_config)
        else:
            compiled_model_path = os.path.join("local-models",
                                               model_name_or_path,
                                               "neuron-compiled-artifacts",
                                               hashed_config)
        try:
            self.model = neuronx_model_cls(compiled_model_path)
            tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
            self.vision_token_id = tokenizer(
                "<|image|>", add_special_tokens=False).input_ids[0]
            self.model.load(compiled_model_path)
            return
        except (FileNotFoundError, ValueError):
            logger.warning("Failed to load the model from %s, Recompiling...",
                           compiled_model_path)
        if not os.path.exists(model_name_or_path):
            hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
            saved_path = os.path.join("local-models", model_name_or_path)
            hf_model.save_pretrained(saved_path)
            model_name_or_path = saved_path
        self.model = neuronx_model_cls(model_name_or_path, config)

        # Report the directory the artifacts are actually written to.
        logger.info("\nCompiling and saving model to %s", compiled_model_path)

        # Compilation runs in a separate process via the module-level
        # compile_model function.
        p = multiprocessing.Process(target=compile_model,
                                    args=(self, compiled_model_path))
        p.start()
        p.join()

        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        tokenizer.save_pretrained(compiled_model_path)
        logger.info("Successfully compiled and saved the model in %s",
                    compiled_model_path)

        # Read "<|image|>" token_id from the tokenizer
        self.vision_token_id = tokenizer("<|image|>",
                                         add_special_tokens=False).input_ids[0]
        logger.info("\nLoading model from compiled checkpoint...")
        self.model.load(compiled_model_path)
def compile_model(neuron_model, traced_model_path):
    """Compile ``neuron_model.model`` into ``traced_model_path``.

    Kept as a module-level function so it can be handed to
    ``multiprocessing.Process`` as a target.
    """
    wrapped_model = neuron_model.model
    wrapped_model.compile(traced_model_path)
class NeuronSpeculationCausalLM(nn.Module):
    """A Neuron-optimized causal language model with speculative decoding."""

    def __init__(
        self,
        config: PretrainedConfig,
    ) -> None:
        """Store ``config`` and build the logits processor.

        ``self.model`` is only declared here; it is assigned later by
        ``load_weights``.
        """
        super().__init__()
        self.config = config
        self.logits_processor = LogitsProcessor(config.vocab_size,
                                                logits_as_input=True)
        # Lazy initialized
        self.model: nn.Module

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_block_ids: torch.Tensor,
        sampling_params: torch.Tensor,
    ) -> torch.Tensor:
        """Run the fused speculation model and return token ids.

        Inputs are reordered by ascending block id before the model call and
        the result rows are put back into the caller's order afterwards
        (skipped when there is a single row). For generation steps, slots
        past the number of tokens produced for a row are set to ``-1``.
        """
        # sort block ids sequentially for perf/neuron support reasons
        sorted_input_block_ids, sorted_indices = torch.sort(input_block_ids)
        input_ids = torch.index_select(input_ids, 0, sorted_indices)
        positions = torch.index_select(positions, 0, sorted_indices)
        sampling_params = torch.index_select(sampling_params, 0,
                                             sorted_indices)

        output = self.model(input_ids,
                            attention_mask=None,
                            position_ids=positions,
                            seq_ids=sorted_input_block_ids,
                            sampling_params=sampling_params)
        # Permutation that undoes the sort above.
        restored_indices = torch.argsort(sorted_indices)

        # CTX encoding
        if (positions[:, 0]).sum().item() == 0:
            # Only the first column of the first fused output is kept.
            output = output.fused_outputs[0][:, 0:1]
            if input_block_ids.shape[0] != 1:
                output = torch.index_select(output, 0, restored_indices)
            return output

        # Fused Spec (Generation)
        accepted_tokens_with_padding = output.fused_outputs[0]
        next_pos_ids = output.fused_outputs[-1]
        generated_token_counts = next_pos_ids - positions

        assert torch.any(generated_token_counts == 0).item() is False, \
            "NxDI model generated no output for one or more sequences."

        batch_size, steps = accepted_tokens_with_padding.shape
        # True for every slot at or beyond a row's generated-token count.
        mask = torch.arange(steps).expand(batch_size,
                                          -1) >= generated_token_counts
        accepted_tokens_with_padding[mask] = -1

        if input_block_ids.shape[0] != 1:
            accepted_tokens_with_padding = torch.index_select(
                accepted_tokens_with_padding, 0, restored_indices)

        return accepted_tokens_with_padding

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[list[SamplerOutput]]:
        """Convert a two-dimensional tensor of token ids into sampler outputs.

        ``logits`` is unpacked as ``(batch_size, num_steps)``. One
        ``SamplerOutput`` is produced per step, stopping at the first step
        in which every entry equals ``-1``.
        """
        batch_size, num_steps = logits.shape
        seq_ids = [
            seq_id for sg in sampling_metadata.seq_groups
            for seq_id in sg.seq_ids
        ]
        # Organize input tensors by step instead of by sequence.
        accepted_token_ids_by_step = logits.transpose(0, 1).tolist()

        sampler_output_list = []
        for step_token_ids in accepted_token_ids_by_step:
            if all(token_id == -1 for token_id in step_token_ids):
                break
            step_output_token_ids = []
            for sequence_index in range(batch_size):
                token_id = step_token_ids[sequence_index]
                step_output_token_ids.append(
                    CompletionSequenceGroupOutput(
                        samples=[
                            SequenceOutput(
                                parent_seq_id=seq_ids[sequence_index],
                                output_token=token_id,
                                logprobs={token_id: Logprob(token_id)})
                        ],
                        prompt_logprobs=None))
            sampler_output_list.append(
                SamplerOutput(outputs=step_output_token_ids))
        return sampler_output_list

    def load_weights(self, model_name_or_path: str,
                     draft_model_name_or_path: str, **kwargs):
        """Load or compile the fused target/draft speculation model.

        Args:
            model_name_or_path: Local path or identifier of the target model.
            draft_model_name_or_path: Local path or identifier of the draft
                model.
            **kwargs: Must contain ``neuron_config`` (a mapping expanded into
                the Neuron config class constructor). May contain
                ``override_neuron_config``; a missing or ``None`` value is
                treated as "no overrides".
        """
        arch = _get_model_architecture(self.config)
        neuronx_module_path, neuronx_model_cls_name = (
            _NEURON_SUPPORTED_MODELS[arch])
        neuronx_module = importlib.import_module(neuronx_module_path)
        neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)
        neuron_config = neuronx_model_cls.get_neuron_config_cls()(
            **kwargs['neuron_config'])
        config = neuronx_model_cls.get_config_cls()(
            neuron_config,
            load_config=load_pretrained_config(model_name_or_path))

        # The draft model starts from a copy of the target's neuron config.
        draft_neuron_config = copy.deepcopy(config.neuron_config)
        if not config.neuron_config.enable_eagle_speculation:
            draft_neuron_config.speculation_length = 0
        draft_neuron_config.trace_tokengen_model = True
        draft_neuron_config.enable_fused_speculation = False
        if getattr(config.neuron_config,
                   "draft_model_modules_to_not_convert", None):
            draft_neuron_config.modules_to_not_convert = (
                draft_neuron_config.draft_model_modules_to_not_convert)
        if config.neuron_config.enable_eagle_speculation:
            draft_neuron_config.is_eagle_draft = True
            draft_neuron_config.sequence_parallel_enabled = False
        draft_config = neuronx_model_cls.get_config_cls()(
            draft_neuron_config,
            load_config=load_pretrained_config(draft_model_name_or_path))
        fused_spec_config = (FusedSpecNeuronConfig(
            neuronx_model_cls._model_cls,
            draft_config=draft_config,
            draft_model_path=draft_model_name_or_path))
        config.fused_spec_config = fused_spec_config
        self.config.neuron_config = neuron_config

        # The digest only names the cache directory; it is not a security
        # hash.
        hashed_config = hashlib.md5(config.to_json_string().encode('utf-8'),
                                    usedforsecurity=False).hexdigest()

        # Read the environment once so the check and the use cannot disagree.
        artifacts_override = os.getenv("NEURON_COMPILED_ARTIFACTS")
        if artifacts_override is not None:
            compiled_model_path = artifacts_override
        elif os.path.exists(model_name_or_path):
            compiled_model_path = os.path.join(model_name_or_path,
                                               "neuron-compiled-artifacts",
                                               hashed_config)
            # Any directory already present at this path is removed.
            # NOTE(review): the load attempt below then presumably fails for
            # this branch and falls through to recompilation - confirm.
            shutil.rmtree(compiled_model_path, ignore_errors=True)
        else:
            compiled_model_path = os.path.join("local-models",
                                               model_name_or_path,
                                               "neuron-compiled-artifacts",
                                               hashed_config)
            shutil.rmtree(compiled_model_path, ignore_errors=True)

        try:
            self.model = neuronx_model_cls(compiled_model_path)
            # Callers pass ``None`` when the user supplied no overrides;
            # treat that (or a missing key) as an empty mapping instead of
            # failing on ``None.items()``.
            override_neuron_config = (kwargs.get("override_neuron_config")
                                      or {})
            for k, v in override_neuron_config.items():
                setattr(self.model.config.neuron_config, k, v)
            self.model.load(compiled_model_path)
            return
        except (FileNotFoundError, ValueError) as e:
            logger.warning("Exception: %s", e)
            logger.warning("Failed to load the model from %s Recompiling...",
                           compiled_model_path)

        if not os.path.exists(model_name_or_path):
            hf_model = AutoModelForCausalLM.from_pretrained(
                model_name_or_path)
            saved_path = os.path.join("local-models", model_name_or_path)
            hf_model.save_pretrained(saved_path)
            model_name_or_path = saved_path
        if not os.path.exists(draft_model_name_or_path):
            if draft_model_name_or_path != model_name_or_path:
                hf_model = AutoModelForCausalLM.from_pretrained(
                    draft_model_name_or_path)
                saved_path = os.path.join("local-models",
                                          draft_model_name_or_path)
                hf_model.save_pretrained(saved_path)
                draft_model_name_or_path = saved_path
            else:
                draft_model_name_or_path = model_name_or_path
            # Keep the fused-speculation config pointing at the final path.
            config.fused_spec_config.draft_model_path = (
                draft_model_name_or_path)

        self.model = neuronx_model_cls(model_name_or_path, config)
        self.model.compile(compiled_model_path)
        self.model.load(compiled_model_path)
def _get_model_architecture(config: PretrainedConfig) -> str:
    """Return the first architecture in ``config`` that Neuron supports.

    Args:
        config: Model config whose ``architectures`` attribute is inspected.
            A missing attribute or an explicit ``None`` is treated as an
            empty list.

    Returns:
        The first entry of ``config.architectures`` that is a key of
        ``_NEURON_SUPPORTED_MODELS``.

    Raises:
        ValueError: If none of the listed architectures is supported.
    """
    architectures = getattr(config, "architectures", None)
    if architectures is None:
        # The attribute can exist but be unset; report "unsupported" below
        # rather than failing with a TypeError while iterating None.
        architectures = []
    for arch in architectures:
        if arch in _NEURON_SUPPORTED_MODELS:
            return arch
    raise ValueError(
        f"Model architectures {architectures} are not supported on Neuron "
        f"for now. Supported architectures: "
        f"{list(_NEURON_SUPPORTED_MODELS.keys())}")
def _get_default_neuron_config(model_config: ModelConfig,
                               parallel_config: ParallelConfig,
                               scheduler_config: SchedulerConfig,
                               lora_serving_config: LoraServingConfig):
    """Build the default neuron config dict from the vLLM config objects.

    The returned value is a plain ``dict`` of keyword arguments; both
    ``max_context_length`` and ``seq_len`` are taken from
    ``scheduler_config.max_model_len``.
    """
    sampling_config = OnDeviceSamplingConfig(dynamic=True,
                                             deterministic=False)
    max_len = scheduler_config.max_model_len
    return {
        "tp_degree": parallel_config.tensor_parallel_size,
        "ctx_batch_size": 1,
        "batch_size": scheduler_config.max_num_seqs,
        "max_context_length": max_len,
        "seq_len": max_len,
        "enable_bucketing": True,
        "is_continuous_batching": True,
        "quantized": False,
        "torch_dtype": TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
        "padding_side": "right",
        "on_device_sampling_config": sampling_config,
        "sequence_parallel_enabled": True,
        "lora_serving_config": lora_serving_config,
    }
def _get_default_speculation_config(model_config: ModelConfig,
                                    parallel_config: ParallelConfig,
                                    scheduler_config: SchedulerConfig,
                                    speculation_config: SpeculativeConfig):
    """Build the default neuron config dict for speculative decoding.

    The returned value is a plain ``dict`` of keyword arguments derived from
    the vLLM config objects; ``speculation_length`` comes from
    ``speculation_config.num_speculative_tokens``.
    """
    max_len = scheduler_config.max_model_len
    return {
        "tp_degree": parallel_config.tensor_parallel_size,
        "ctx_batch_size": 1,
        "batch_size": scheduler_config.max_num_seqs,
        "max_context_length": max_len,
        "seq_len": max_len,
        "speculation_length": speculation_config.num_speculative_tokens,
        "trace_tokengen_model": False,
        "enable_fused_speculation": True,
        "enable_bucketing": True,
        "is_continuous_batching": True,
        "quantized": False,
        "torch_dtype": TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
        "on_device_sampling_config": {
            "top_k": 1,
            "do_sample": False,
        },
    }
def
_get_neuron_config_after_override
(
default_neuron_config
,
overridden_neuron_config
):
"""Update default neuron config values with override args"""
overridden_neuron_config
=
overridden_neuron_config
or
{}
default_neuron_config
.
update
(
overridden_neuron_config
)
return
default_neuron_config
def get_neuron_model(model_config: ModelConfig,
                     parallel_config: ParallelConfig,
                     scheduler_config: SchedulerConfig,
                     lora_serving_config: LoraServingConfig) -> nn.Module:
    """Create, load and return a neuron-optimized model in eval mode.

    ``MllamaForConditionalGeneration`` is wrapped by
    ``NeuronMllamaForCausalLM``; every other supported architecture uses
    ``NeuronCausalLM``.
    """
    hf_config = model_config.hf_config
    if _get_model_architecture(hf_config) == "MllamaForConditionalGeneration":
        wrapper_cls = NeuronMllamaForCausalLM
    else:
        wrapper_cls = NeuronCausalLM
    model = wrapper_cls(hf_config)

    defaults = _get_default_neuron_config(model_config, parallel_config,
                                          scheduler_config,
                                          lora_serving_config)
    merged_config = _get_neuron_config_after_override(
        defaults, model_config.override_neuron_config)

    model.load_weights(
        model_config.model,
        neuron_config=merged_config,
        override_neuron_config=model_config.override_neuron_config)
    return model.eval()
def get_neuron_speculation_model(model_config: ModelConfig,
                                 parallel_config: ParallelConfig,
                                 scheduler_config: SchedulerConfig,
                                 speculation_config: SpeculativeConfig):
    """Create, load and return a neuron-optimized speculation model.

    This model handles speculation using both a draft model and an EAGLE
    draft. The draft checkpoint is taken from
    ``speculation_config.draft_model_config.model`` and the result is
    returned in eval mode.
    """
    model = NeuronSpeculationCausalLM(model_config.hf_config)

    defaults = _get_default_speculation_config(model_config, parallel_config,
                                               scheduler_config,
                                               speculation_config)
    merged_config = _get_neuron_config_after_override(
        defaults, model_config.override_neuron_config)

    model.load_weights(
        model_config.model,
        speculation_config.draft_model_config.model,
        neuron_config=merged_config,
        override_neuron_config=model_config.override_neuron_config)
    return model.eval()
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment