Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4172235a
Unverified
Commit
4172235a
authored
Sep 06, 2025
by
Woosuk Kwon
Committed by
GitHub
Sep 06, 2025
Browse files
[V0 deprecation] Deprecate V0 Neuron backend (#21159)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
848562bd
Changes
46
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2 additions
and
1624 deletions
+2
-1624
.buildkite/release-pipeline.yaml
.buildkite/release-pipeline.yaml
+0
-16
.buildkite/scripts/hardware_ci/run-neuron-test.sh
.buildkite/scripts/hardware_ci/run-neuron-test.sh
+0
-64
MANIFEST.in
MANIFEST.in
+0
-1
docker/Dockerfile.neuron
docker/Dockerfile.neuron
+0
-56
examples/offline_inference/neuron.py
examples/offline_inference/neuron.py
+0
-49
examples/offline_inference/neuron_eagle.py
examples/offline_inference/neuron_eagle.py
+0
-61
examples/offline_inference/neuron_int8_quantization.py
examples/offline_inference/neuron_int8_quantization.py
+0
-63
examples/offline_inference/neuron_multimodal.py
examples/offline_inference/neuron_multimodal.py
+0
-110
examples/offline_inference/neuron_speculation.py
examples/offline_inference/neuron_speculation.py
+0
-64
requirements/neuron.txt
requirements/neuron.txt
+0
-9
setup.py
setup.py
+2
-34
tests/engine/test_arg_utils.py
tests/engine/test_arg_utils.py
+0
-9
tests/neuron/1_core/test_activation.py
tests/neuron/1_core/test_activation.py
+0
-43
tests/neuron/1_core/test_block_table.py
tests/neuron/1_core/test_block_table.py
+0
-154
tests/neuron/1_core/test_cache.py
tests/neuron/1_core/test_cache.py
+0
-86
tests/neuron/1_core/test_layernorm.py
tests/neuron/1_core/test_layernorm.py
+0
-57
tests/neuron/1_core/test_logits_processor.py
tests/neuron/1_core/test_logits_processor.py
+0
-95
tests/neuron/1_core/test_neuron_model_runner.py
tests/neuron/1_core/test_neuron_model_runner.py
+0
-127
tests/neuron/1_core/test_neuron_quant.py
tests/neuron/1_core/test_neuron_quant.py
+0
-12
tests/neuron/1_core/test_prefix_prefill.py
tests/neuron/1_core/test_prefix_prefill.py
+0
-514
No files found.
.buildkite/release-pipeline.yaml
View file @
4172235a
...
@@ -149,19 +149,3 @@ steps:
...
@@ -149,19 +149,3 @@ steps:
-
"
docker
push
public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent
meta-data
get
release-version)"
-
"
docker
push
public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent
meta-data
get
release-version)"
env
:
env
:
DOCKER_BUILDKIT
:
"
1"
DOCKER_BUILDKIT
:
"
1"
-
block
:
"
Build
Neuron
release
image"
key
:
block-neuron-release-image-build
depends_on
:
~
-
label
:
"
Build
and
publish
Neuron
release
image"
depends_on
:
block-neuron-release-image-build
agents
:
queue
:
neuron-postmerge
commands
:
-
"
aws
ecr-public
get-login-password
--region
us-east-1
|
docker
login
--username
AWS
--password-stdin
public.ecr.aws/q9t5s3a7"
-
"
DOCKER_BUILDKIT=1
docker
build
--build-arg
max_jobs=16
--build-arg
GIT_REPO_CHECK=1
--tag
public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:$(buildkite-agent
meta-data
get
release-version)
--tag
public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:latest
--progress
plain
-f
docker/Dockerfile.neuron
."
-
"
docker
push
public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:latest"
-
"
docker
push
public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:$(buildkite-agent
meta-data
get
release-version)"
env
:
DOCKER_BUILDKIT
:
"
1"
.buildkite/scripts/hardware_ci/run-neuron-test.sh
deleted
100644 → 0
View file @
848562bd
#!/bin/bash

# This script builds the Neuron docker image and runs the offline inference
# example and the Neuron test suites inside the container.
# It serves as a sanity check for compilation and basic model usage.
set -e
set -v

image_name="neuron/vllm-ci"
# Random 10-character suffix for the container name.
container_name="neuron_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)"

# Host-side Hugging Face cache directory, bind-mounted into the container below.
HF_CACHE="$(realpath ~)/huggingface"
mkdir -p "${HF_CACHE}"
HF_MOUNT="/root/.cache/huggingface"
# The Hugging Face token is read from AWS Secrets Manager and parsed with jq.
HF_TOKEN=$(aws secretsmanager get-secret-value --secret-id "ci/vllm-neuron/hf-token" --region us-west-2 --query 'SecretString' --output text | jq -r .VLLM_NEURON_CI_HF_TOKEN)

# Host-side Neuron compile cache directory, bind-mounted into the container below.
NEURON_COMPILE_CACHE_URL="$(realpath ~)/neuron_compile_cache"
mkdir -p "${NEURON_COMPILE_CACHE_URL}"
NEURON_COMPILE_CACHE_MOUNT="/root/.cache/neuron_compile_cache"

# Try building the docker image
aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws

# Prune old images and containers to save disk space, at most once a day,
# tracked with a timestamp file in /tmp.
if [ -f /tmp/neuron-docker-build-timestamp ]; then
    last_build=$(cat /tmp/neuron-docker-build-timestamp)
    current_time=$(date +%s)
    if [ $((current_time - last_build)) -gt 86400 ]; then
        # Remove dangling images (those that are not tagged and not used by any container)
        docker image prune -f
        # Remove unused volumes / force the system prune for old images as well.
        docker volume prune -f && docker system prune -f
        echo "$current_time" > /tmp/neuron-docker-build-timestamp
    fi
else
    date "+%s" > /tmp/neuron-docker-build-timestamp
fi

docker build -t "${image_name}" -f docker/Dockerfile.neuron .

# Setup cleanup
# NOTE: despite its name, this removes the built *image*; the container itself
# is removed by `docker run --rm` below.
remove_docker_container() {
    docker image rm -f "${image_name}" || true;
}
trap remove_docker_container EXIT

# Run the image
docker run --rm -it --device=/dev/neuron0 --network bridge \
       -v "${HF_CACHE}:${HF_MOUNT}" \
       -e "HF_HOME=${HF_MOUNT}" \
       -e "HF_TOKEN=${HF_TOKEN}" \
       -v "${NEURON_COMPILE_CACHE_URL}:${NEURON_COMPILE_CACHE_MOUNT}" \
       -e "NEURON_COMPILE_CACHE_URL=${NEURON_COMPILE_CACHE_MOUNT}" \
       --name "${container_name}" \
       ${image_name} \
       /bin/bash -c "
            set -e; # Exit on first error
            python3 /workspace/vllm/examples/offline_inference/neuron.py;
            python3 -m pytest /workspace/vllm/tests/neuron/1_core/ -v --capture=tee-sys;
            for f in /workspace/vllm/tests/neuron/2_core/*.py; do
                echo \"Running test file: \$f\";
                python3 -m pytest \$f -v --capture=tee-sys;
            done
       "
\ No newline at end of file
MANIFEST.in
View file @
4172235a
...
@@ -2,7 +2,6 @@ include LICENSE
...
@@ -2,7 +2,6 @@ include LICENSE
include requirements/common.txt
include requirements/common.txt
include requirements/cuda.txt
include requirements/cuda.txt
include requirements/rocm.txt
include requirements/rocm.txt
include requirements/neuron.txt
include requirements/cpu.txt
include requirements/cpu.txt
include CMakeLists.txt
include CMakeLists.txt
...
...
docker/Dockerfile.neuron
deleted
100644 → 0
View file @
848562bd
# default base image
# https://gallery.ecr.aws/neuron/pytorch-inference-neuronx
ARG BASE_IMAGE="public.ecr.aws/neuron/pytorch-inference-neuronx:2.6.0-neuronx-py310-sdk2.23.0-ubuntu22.04"
FROM $BASE_IMAGE
RUN echo "Base image is $BASE_IMAGE"
# Install some basic utilities
RUN apt-get update && \
apt-get install -y \
git \
python3 \
python3-pip \
ffmpeg libsm6 libxext6 libgl1
### Mount Point ###
# When launching the container, mount the code directory to /workspace
ARG APP_MOUNT=/workspace
VOLUME [ ${APP_MOUNT} ]
WORKDIR ${APP_MOUNT}/vllm
RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas tenacity
RUN python3 -m pip install neuronx-cc==2.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U
RUN python3 -m pip install pytest
# uninstall transformers-neuronx package explicitly to avoid version conflict
RUN python3 -m pip uninstall -y transformers-neuronx
COPY . .
ARG GIT_REPO_CHECK=0
RUN --mount=type=bind,source=.git,target=.git \
if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi
RUN python3 -m pip install -U \
'cmake>=3.26.1' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
-r requirements/neuron.txt
ENV VLLM_TARGET_DEVICE neuron
RUN --mount=type=bind,source=.git,target=.git \
pip install --no-build-isolation -v -e .
# install development dependencies (for testing)
RUN python3 -m pip install -e tests/vllm_test_utils
# install the transformers-neuronx package as an optional dependency (for V0)
# FIXME: `--no-deps` argument is temporarily added to resolve transformers package version conflict
RUN python3 -m pip install transformers-neuronx==0.13.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U --no-deps
RUN python3 -m pip install sentencepiece transformers==4.48.0 -U
# overwrite entrypoint to run bash script
RUN echo "import subprocess; import sys; subprocess.check_call(sys.argv[1:])" > /usr/local/bin/dockerd-entrypoint.py
CMD ["/bin/bash"]
examples/offline_inference/neuron.py
deleted
100644 → 0
View file @
848562bd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
vllm
import
LLM
,
SamplingParams
# Sample prompts.
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
]
# Create a sampling params object.
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
)
def main():
    """Generate completions for the sample prompts on Neuron and print them."""
    engine_kwargs = dict(
        model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        max_num_seqs=8,
        # The max_model_len and block_size arguments are required to be same as
        # max sequence length when targeting neuron device.
        # Currently, this is a known limitation in continuous batching support
        # in transformers-neuronx.
        # TODO(liangfu): Support paged-attention in transformers-neuronx.
        max_model_len=1024,
        block_size=1024,
        # ruff: noqa: E501
        # The device can be automatically detected when AWS Neuron SDK is installed.
        # The device argument can be either unspecified for automated detection,
        # or explicitly assigned.
        device="neuron",
        tensor_parallel_size=2,
    )
    llm = LLM(**engine_kwargs)

    # llm.generate returns one RequestOutput per prompt; each carries the
    # prompt, the generated text, and other information.
    separator = "-" * 50
    print(separator)
    for request_output in llm.generate(prompts, sampling_params):
        prompt = request_output.prompt
        generated_text = request_output.outputs[0].text
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
        print(separator)
if
__name__
==
"__main__"
:
main
()
examples/offline_inference/neuron_eagle.py
deleted
100644 → 0
View file @
848562bd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example shows how to run offline inference with an EAGLE speculative
decoding model on neuron. To use EAGLE speculative decoding, you must use
a draft model that is specifically fine-tuned for EAGLE speculation.
Additionally, to use EAGLE with NxD Inference, the draft model must include
the LM head weights from the target model. These weights are shared between
the draft and target model.
"""
from
vllm
import
LLM
,
SamplingParams
# Sample prompts.
prompts
=
[
"What is annapurna labs?"
,
]
def main():
    """Run EAGLE speculative decoding on Neuron and print each completion."""
    # Sampling setup: top_k=1, 500 new tokens, EOS ignored.
    sampling_params = SamplingParams(top_k=1, max_tokens=500, ignore_eos=True)

    draft_config = {
        "model": "/home/ubuntu/model_hf/Llama-3.1-70B-Instruct-EAGLE-Draft",
        "num_speculative_tokens": 5,
        "max_model_len": 2048,
    }
    neuron_overrides = {
        "enable_eagle_speculation": True,
        "enable_fused_speculation": True,
    }
    llm = LLM(
        model="/home/ubuntu/model_hf/Meta-Llama-3.1-70B-Instruct",
        speculative_config=draft_config,
        max_num_seqs=4,
        # The max_model_len and block_size arguments are required to be same as
        # max sequence length when targeting neuron device.
        # Currently, this is a known limitation in continuous batching support
        # in neuronx-distributed-inference.
        max_model_len=2048,
        block_size=2048,
        # The device can be automatically detected when AWS Neuron SDK is installed.
        # The device argument can be either unspecified for automated detection,
        # or explicitly assigned.
        device="neuron",
        tensor_parallel_size=32,
        override_neuron_config=neuron_overrides,
    )

    # One RequestOutput per prompt; print the prompt and its first completion.
    for request_output in llm.generate(prompts, sampling_params):
        prompt = request_output.prompt
        generated_text = request_output.outputs[0].text
        print(f"Prompt: {prompt!r}, \n\n\n Generated text: {generated_text!r}")
if
__name__
==
"__main__"
:
main
()
examples/offline_inference/neuron_int8_quantization.py
deleted
100644 → 0
View file @
848562bd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
os
from
vllm
import
LLM
,
SamplingParams
# creates XLA hlo graphs for all the context length buckets.
os
.
environ
[
"NEURON_CONTEXT_LENGTH_BUCKETS"
]
=
"128,512,1024,2048"
# creates XLA hlo graphs for all the token gen buckets.
os
.
environ
[
"NEURON_TOKEN_GEN_BUCKETS"
]
=
"128,512,1024,2048"
# Quantize the Neuron model weights to int8.
# The default quantization config uses the int8 dtype.
os
.
environ
[
"NEURON_QUANT_DTYPE"
]
=
"s8"
# Sample prompts.
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
]
# Create a sampling params object.
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
)
def main():
    """Run the sample prompts through an int8-quantized model on Neuron."""
    engine_kwargs = dict(
        model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        max_num_seqs=8,
        # The max_model_len and block_size arguments are required to be same as
        # max sequence length when targeting neuron device.
        # Currently, this is a known limitation in continuous batching support
        # in transformers-neuronx.
        # TODO(liangfu): Support paged-attention in transformers-neuronx.
        max_model_len=2048,
        block_size=2048,
        # ruff: noqa: E501
        # The device can be automatically detected when AWS Neuron SDK is installed.
        # The device argument can be either unspecified for automated detection,
        # or explicitly assigned.
        device="neuron",
        quantization="neuron_quant",
        override_neuron_config={
            "cast_logits_dtype": "bfloat16",
        },
        tensor_parallel_size=2,
    )
    llm = LLM(**engine_kwargs)

    # llm.generate returns one RequestOutput per prompt; each carries the
    # prompt, the generated text, and other information.
    separator = "-" * 50
    print(separator)
    for request_output in llm.generate(prompts, sampling_params):
        prompt = request_output.prompt
        generated_text = request_output.outputs[0].text
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
        print(separator)
if
__name__
==
"__main__"
:
main
()
examples/offline_inference/neuron_multimodal.py
deleted
100644 → 0
View file @
848562bd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
requests
import
torch
from
neuronx_distributed_inference.models.mllama.utils
import
add_instruct
from
PIL
import
Image
from
vllm
import
LLM
,
SamplingParams
,
TextPrompt
def get_image(image_url):
    """Download ``image_url`` and open it as a PIL image.

    Args:
        image_url: HTTP(S) URL of the image to fetch.

    Returns:
        The image opened by ``Image.open`` from the streamed response body.

    Raises:
        requests.HTTPError: the server answered with an error status.
        requests.Timeout: the server did not respond within the timeout.
    """
    # Without a timeout requests may block indefinitely on a stalled server.
    response = requests.get(image_url, stream=True, timeout=30)
    # Fail on 4xx/5xx here instead of handing an error page to PIL.
    response.raise_for_status()
    return Image.open(response.raw)
# Model Inputs
PROMPTS
=
[
"What is in this image? Tell me a story"
,
"What is the recipe of mayonnaise in two sentences?"
,
"Describe this image"
,
"What is the capital of Italy famous for?"
,
]
IMAGES
=
[
get_image
(
"https://images.pexels.com/photos/1108099/pexels-photo-1108099.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=500"
),
None
,
get_image
(
"https://images.pexels.com/photos/1108099/pexels-photo-1108099.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=500"
),
None
,
]
SAMPLING_PARAMS
=
[
dict
(
top_k
=
1
,
temperature
=
1.0
,
top_p
=
1.0
,
max_tokens
=
16
)
for
_
in
range
(
len
(
PROMPTS
))
]
def get_VLLM_mllama_model_inputs(prompt, single_image, sampling_params):
    """Build the prompt object and sampling parameters for one request.

    Steps:
      1. wrap the text prompt with ``add_instruct``
      2. put the text (and the image, when given) into a ``TextPrompt``
      3. turn the ``sampling_params`` dict into a ``SamplingParams``

    Returns:
        A ``(inputs, sampling_params)`` tuple.
    """
    # NOTE(review): the flag is 0 only for an *empty tensor*; it stays 1 when
    # single_image is None - confirm this is what add_instruct expects for
    # text-only prompts.
    is_empty_tensor = (
        isinstance(single_image, torch.Tensor) and single_image.numel() == 0
    )
    has_image = torch.tensor([0]) if is_empty_tensor else torch.tensor([1])

    inputs = TextPrompt(prompt=add_instruct(prompt, has_image))
    if single_image is not None:
        inputs["multi_modal_data"] = {"image": single_image}

    return inputs, SamplingParams(**sampling_params)
def print_outputs(outputs):
    """Print the prompt and the first completion's text of every output."""
    for request_output in outputs:
        prompt_text = request_output.prompt
        completion_text = request_output.outputs[0].text
        line = f"Prompt: {prompt_text!r}, Generated text: {completion_text!r}"
        print(line)
def main():
    """Run the multimodal prompts one at a time, then as a single batch.

    Raises:
        ValueError: PROMPTS, IMAGES and SAMPLING_PARAMS differ in length.
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`;
    # the message is built on one line so it carries no source indentation.
    if not (len(PROMPTS) == len(IMAGES) == len(SAMPLING_PARAMS)):
        raise ValueError(
            "Text, image prompts and sampling parameters should have the "
            f"same batch size; but got {len(PROMPTS)}, {len(IMAGES)}, "
            f"and {len(SAMPLING_PARAMS)}"
        )

    # Create an LLM.
    llm = LLM(
        model="meta-llama/Llama-3.2-11B-Vision-Instruct",
        max_num_seqs=1,
        max_model_len=4096,
        block_size=4096,
        device="neuron",
        tensor_parallel_size=32,
        override_neuron_config={
            "sequence_parallel_enabled": False,
            "skip_warmup": True,
            "save_sharded_checkpoint": True,
            "on_device_sampling_config": {
                "global_topk": 1,
                "dynamic": False,
                "deterministic": False,
            },
        },
    )

    batched_inputs = []
    batched_sample_params = []
    for pmpt, img, params in zip(PROMPTS, IMAGES, SAMPLING_PARAMS):
        inputs, sampling_params = get_VLLM_mllama_model_inputs(pmpt, img, params)
        # test batch-size = 1
        outputs = llm.generate(inputs, sampling_params)
        print_outputs(outputs)
        batched_inputs.append(inputs)
        batched_sample_params.append(sampling_params)

    # test the full batch (one entry per prompt)
    outputs = llm.generate(batched_inputs, batched_sample_params)
    print_outputs(outputs)
if
__name__
==
"__main__"
:
main
()
examples/offline_inference/neuron_speculation.py
deleted
100644 → 0
View file @
848562bd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example shows how to run offline inference with a speculative
decoding model on neuron.
"""
import
os
from
vllm
import
LLM
,
SamplingParams
# Sample prompts.
prompts
=
[
"Hello, I am a language model and I can help"
,
"The president of the United States is"
,
"The capital of France is"
,
]
def config_buckets():
    """Configure context length and token gen buckets."""
    bucket_sizes = "128,512,1024,2048"
    os.environ.update({
        # creates XLA hlo graphs for all the context length buckets.
        "NEURON_CONTEXT_LENGTH_BUCKETS": bucket_sizes,
        # creates XLA hlo graphs for all the token gen buckets.
        "NEURON_TOKEN_GEN_BUCKETS": bucket_sizes,
    })
def initialize_llm():
    """Create an LLM with speculative decoding."""
    draft_config = {
        "model": "openlm-research/open_llama_3b",
        "num_speculative_tokens": 4,
        "max_model_len": 2048,
    }
    engine_kwargs = dict(
        model="openlm-research/open_llama_7b",
        speculative_config=draft_config,
        max_num_seqs=4,
        max_model_len=2048,
        block_size=2048,
        device="neuron",
        tensor_parallel_size=32,
    )
    return LLM(**engine_kwargs)
def process_requests(llm: LLM, sampling_params: SamplingParams):
    """Generate texts from prompts and print them."""
    for request_output in llm.generate(prompts, sampling_params):
        prompt_text = request_output.prompt
        completion_text = request_output.outputs[0].text
        line = f"Prompt: {prompt_text!r}, Generated text: {completion_text!r}"
        print(line)
def main():
    """Main function that sets up the llm and processes prompts."""
    # Bucket environment variables are set before the LLM is created.
    config_buckets()
    engine = initialize_llm()
    # Sampling setup: up to 100 new tokens with top_k=1.
    params = SamplingParams(max_tokens=100, top_k=1)
    process_requests(engine, params)
if
__name__
==
"__main__"
:
main
()
requirements/neuron.txt
deleted
100644 → 0
View file @
848562bd
# Common dependencies
-r common.txt
# Dependencies for Neuron devices
packaging>=24.2
setuptools>=77.0.3,<80.0.0
torch-neuronx >= 2.5.0
neuronx-cc>=2.0.0a0
torchvision # Required for Llama3.2 multimodal image preprocessing
setup.py
View file @
4172235a
...
@@ -413,8 +413,7 @@ def _no_device() -> bool:
...
@@ -413,8 +413,7 @@ def _no_device() -> bool:
def
_is_cuda
()
->
bool
:
def
_is_cuda
()
->
bool
:
has_cuda
=
torch
.
version
.
cuda
is
not
None
has_cuda
=
torch
.
version
.
cuda
is
not
None
return
(
VLLM_TARGET_DEVICE
==
"cuda"
and
has_cuda
return
(
VLLM_TARGET_DEVICE
==
"cuda"
and
has_cuda
and
not
_is_tpu
())
and
not
(
_is_neuron
()
or
_is_tpu
()))
def
_is_hip
()
->
bool
:
def
_is_hip
()
->
bool
:
...
@@ -422,10 +421,6 @@ def _is_hip() -> bool:
...
@@ -422,10 +421,6 @@ def _is_hip() -> bool:
or
VLLM_TARGET_DEVICE
==
"rocm"
)
and
torch
.
version
.
hip
is
not
None
or
VLLM_TARGET_DEVICE
==
"rocm"
)
and
torch
.
version
.
hip
is
not
None
def
_is_neuron
()
->
bool
:
return
VLLM_TARGET_DEVICE
==
"neuron"
def
_is_tpu
()
->
bool
:
def
_is_tpu
()
->
bool
:
return
VLLM_TARGET_DEVICE
==
"tpu"
return
VLLM_TARGET_DEVICE
==
"tpu"
...
@@ -470,25 +465,6 @@ def get_rocm_version():
...
@@ -470,25 +465,6 @@ def get_rocm_version():
return
None
return
None
def get_neuronxcc_version():
    """Return the installed ``neuronxcc`` version string.

    The version is parsed from ``neuronxcc/version/__init__.py`` under the
    interpreter's ``purelib`` directory; the package is not imported.

    Returns:
        The value assigned to ``__version__`` in that file.

    Raises:
        OSError: the version file cannot be opened.
        RuntimeError: the file contains no ``__version__`` assignment.
    """
    import sysconfig

    site_dir = sysconfig.get_paths()["purelib"]
    version_file = os.path.join(site_dir, "neuronxcc", "version",
                                "__init__.py")

    with open(version_file) as fp:
        content = fp.read()

    # Accept either quote style and flexible spacing around "=".
    match = re.search(r"""__version__\s*=\s*(['"])(\S+?)\1""", content)
    if match is None:
        raise RuntimeError(
            f"Could not find Neuron version in {version_file}")
    return match.group(2)
def
get_nvcc_cuda_version
()
->
Version
:
def
get_nvcc_cuda_version
()
->
Version
:
"""Get the CUDA version from nvcc.
"""Get the CUDA version from nvcc.
...
@@ -541,12 +517,6 @@ def get_vllm_version() -> str:
...
@@ -541,12 +517,6 @@ def get_vllm_version() -> str:
rocm_version
=
get_rocm_version
()
or
torch
.
version
.
hip
rocm_version
=
get_rocm_version
()
or
torch
.
version
.
hip
if
rocm_version
and
rocm_version
!=
MAIN_CUDA_VERSION
:
if
rocm_version
and
rocm_version
!=
MAIN_CUDA_VERSION
:
version
+=
f
"
{
sep
}
rocm
{
rocm_version
.
replace
(
'.'
,
''
)[:
3
]
}
"
version
+=
f
"
{
sep
}
rocm
{
rocm_version
.
replace
(
'.'
,
''
)[:
3
]
}
"
elif
_is_neuron
():
# Get the Neuron version
neuron_version
=
str
(
get_neuronxcc_version
())
if
neuron_version
!=
MAIN_CUDA_VERSION
:
neuron_version_str
=
neuron_version
.
replace
(
"."
,
""
)[:
3
]
version
+=
f
"
{
sep
}
neuron
{
neuron_version_str
}
"
elif
_is_tpu
():
elif
_is_tpu
():
version
+=
f
"
{
sep
}
tpu"
version
+=
f
"
{
sep
}
tpu"
elif
_is_cpu
():
elif
_is_cpu
():
...
@@ -591,8 +561,6 @@ def get_requirements() -> list[str]:
...
@@ -591,8 +561,6 @@ def get_requirements() -> list[str]:
requirements
=
modified_requirements
requirements
=
modified_requirements
elif
_is_hip
():
elif
_is_hip
():
requirements
=
_read_requirements
(
"rocm.txt"
)
requirements
=
_read_requirements
(
"rocm.txt"
)
elif
_is_neuron
():
requirements
=
_read_requirements
(
"neuron.txt"
)
elif
_is_tpu
():
elif
_is_tpu
():
requirements
=
_read_requirements
(
"tpu.txt"
)
requirements
=
_read_requirements
(
"tpu.txt"
)
elif
_is_cpu
():
elif
_is_cpu
():
...
@@ -601,7 +569,7 @@ def get_requirements() -> list[str]:
...
@@ -601,7 +569,7 @@ def get_requirements() -> list[str]:
requirements
=
_read_requirements
(
"xpu.txt"
)
requirements
=
_read_requirements
(
"xpu.txt"
)
else
:
else
:
raise
ValueError
(
raise
ValueError
(
"Unsupported platform, please use CUDA, ROCm,
Neuron,
or CPU."
)
"Unsupported platform, please use CUDA, ROCm, or CPU."
)
return
requirements
return
requirements
...
...
tests/engine/test_arg_utils.py
View file @
4172235a
...
@@ -287,15 +287,6 @@ def test_prefix_cache_default():
...
@@ -287,15 +287,6 @@ def test_prefix_cache_default():
},
},
"mm-processor-kwargs"
"mm-processor-kwargs"
),
),
(
'{"cast_logits_dtype":"bfloat16","sequence_parallel_norm":true,"sequence_parallel_norm_threshold":2048}'
,
{
"cast_logits_dtype"
:
"bfloat16"
,
"sequence_parallel_norm"
:
True
,
"sequence_parallel_norm_threshold"
:
2048
,
},
"override-neuron-config"
),
])
])
# yapf: enable
# yapf: enable
def
test_composite_arg_parser
(
arg
,
expected
,
option
):
def
test_composite_arg_parser
(
arg
,
expected
,
option
):
...
...
tests/neuron/1_core/test_activation.py
deleted
100644 → 0
View file @
848562bd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
import
torch
import
torch.nn.functional
as
F
from
vllm.model_executor.layers.activation
import
FastGELU
,
SiluAndMul
from
vllm.platforms
import
current_platform
@pytest.mark.parametrize("activation", ["silu_and_mul", "gelu_fast"])
@pytest.mark.parametrize("num_tokens,d,dtype", [
    (7, 512, torch.half),
    (7, 512, torch.float),
    (83, 512, torch.half),
])
@torch.inference_mode()
def test_act_and_mul(
    activation: str,
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
) -> None:
    """Compare each layer's ``forward_neuron`` on XLA with a CPU reference."""
    import torch_xla.core.xla_model as xm

    xla_device = xm.xla_device()
    current_platform.seed_everything(0)
    torch.set_default_device("cpu")
    # The input is created on CPU and then moved to the XLA device.
    x = torch.randn(num_tokens, 2 * d, dtype=dtype).to(device=xla_device)

    if activation not in ("silu_and_mul", "gelu_fast"):
        raise NotImplementedError(
            f"activation {activation} is not implemented.")
    if activation == "silu_and_mul":
        layer = SiluAndMul()
        reference_fn = layer.forward_native
    else:
        layer = FastGELU()
        reference_fn = F.gelu

    assert x.is_xla, "input tensor under testing is expected to be XLA tensor."
    actual = layer.to(device=xla_device).forward_neuron(x)
    expected = reference_fn(x.cpu())
    torch.testing.assert_close(actual.cpu(), expected, atol=0.01, rtol=0.0)
tests/neuron/1_core/test_block_table.py
deleted
100644 → 0
View file @
848562bd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
neuronxcc.nki.language
as
nl
import
pytest
import
torch
import
torch.nn.functional
as
F
from
neuronxcc
import
nki
from
vllm.attention.ops.nki_flash_attn
import
(
load_block_tables
,
transform_block_tables_for_indirect_load
)
def is_power_of_2(n):
    """Return True when ``n`` is positive and has exactly one bit set."""
    if n <= 0:
        return False
    # Clearing the lowest set bit leaves zero only for powers of two.
    return not (n & (n - 1))
def nki_load_and_transform_block_tables(
    block_tables,
    num_tiles,
    num_blocks_per_tile,
    num_head,
    head_id,
    block_size_tiling_factor,
):
    """Load and transform block tables, then copy the result to an output buffer.

    Calls ``load_block_tables`` followed by
    ``transform_block_tables_for_indirect_load`` and stores every row of the
    transformed table into a freshly allocated ``nl.shared_hbm`` int32 array,
    which is returned. The test in this file launches it through ``nki.jit``.

    ``num_blocks_per_tile`` must be a power of two.
    """
    assert is_power_of_2(
        num_blocks_per_tile), f"{num_blocks_per_tile=} must be power of 2"
    block_tables_sbuf = load_block_tables(block_tables, num_tiles,
                                          num_blocks_per_tile)

    # we need to pass an Index as head_id
    head_id = nl.arange(1)[None, :] + head_id

    block_tables_transposed = transform_block_tables_for_indirect_load(
        block_tables_sbuf, block_size_tiling_factor, num_head, head_id)
    # The transformed table's second dimension is expected to be 128 wide.
    B_P_SIZE = 128
    assert block_tables_transposed.shape[1] == B_P_SIZE

    out = nl.ndarray(
        block_tables_transposed.shape,
        dtype=nl.int32,
        buffer=nl.shared_hbm,
    )
    # Copy the result row by row into the output array.
    for i in nl.affine_range(block_tables_transposed.shape[0]):
        nl.store(dst=out[i], value=block_tables_transposed[i])
    return out
def ref_block_tables_transform(
    block_tables,
    num_tiles,
    num_blocks_per_tile,
    num_head,
    head_id,
    block_size_tiling_factor,
):
    """PyTorch reference for the block-table transform checked by the test.

    Takes a flat tensor of ``num_tiles * num_blocks_per_tile`` block ids and
    returns a tensor of shape ``(rows // 128, 128, padded_tiles)``, where
    ``padded_tiles`` is ``num_tiles`` rounded up to a multiple of 128 and
    ``rows = num_blocks_per_tile * block_size_tiling_factor``.
    """
    tile_width = 128
    assert block_tables.numel() == num_tiles * num_blocks_per_tile

    # Round the tile count up to a multiple of tile_width; pad with zero rows.
    padded_tiles = -(-num_tiles // tile_width) * tile_width
    table = block_tables.view(num_tiles, num_blocks_per_tile)
    table = F.pad(table, (0, 0, 0, padded_tiles - num_tiles), "constant", 0)

    # Scale every block id by num_head and add head_id.
    head_indexed = (table * num_head + head_id).view(padded_tiles,
                                                     num_blocks_per_tile, 1)

    # Expand each entry into block_size_tiling_factor consecutive ids.
    sub_block_offsets = torch.arange(0, block_size_tiling_factor).view(1, 1, -1)
    expanded = head_indexed * block_size_tiling_factor + sub_block_offsets

    transposed = expanded.view(padded_tiles, -1).t()
    rows = transposed.shape[0]
    assert rows % tile_width == 0
    return transposed.view(rows // tile_width, tile_width, padded_tiles)
@pytest.mark.parametrize(
    "q_head_per_kv_head,head_id",
    [
        (1, 0),
        (3, 1),
    ],
)
@pytest.mark.parametrize(
    "num_tiles,num_blocks_per_tile",
    [
        (1, 1),
        (13, 16),
        (17, 128),
        (35, 512),
        (128, 128),
        (130, 64),
        (280, 256),
        (315, 1),
    ],
)
@torch.inference_mode()
def test_load_and_transform_block_tables(
    monkeypatch: pytest.MonkeyPatch,
    num_tiles,
    num_blocks_per_tile,
    q_head_per_kv_head,
    head_id,
) -> None:
    """Check the NKI block-table kernel against the PyTorch reference.

    Random block ids are run through ``nki_load_and_transform_block_tables``
    on the XLA device and through ``ref_block_tables_transform`` on CPU; the
    two results must have the same shape and identical values.
    """
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()

    compiler_flags_str = " ".join([
        "-O1",
        "--retry_failed_compilation",
    ])
    # NEURON_CC_FLAGS is set only for the duration of this context.
    with monkeypatch.context() as m:
        m.setenv("NEURON_CC_FLAGS", compiler_flags_str)

        torch.manual_seed(10000)
        torch.set_printoptions(sci_mode=False)

        # On Neuron, we need B_P_SIZE = 128 blocks to make DMA efficient
        B_P_SIZE = 128
        if num_blocks_per_tile < B_P_SIZE:
            # Split each block so that every tile still yields B_P_SIZE rows.
            assert B_P_SIZE % num_blocks_per_tile == 0
            block_size_tiling_factor = B_P_SIZE // num_blocks_per_tile
        else:
            block_size_tiling_factor = 1
        max_num_blocks = 100000
        block_tables = torch.randint(
            0,
            max_num_blocks,
            (num_tiles * num_blocks_per_tile, ),
            dtype=torch.int32,
        )
        # Launch the kernel with a [1, 1] grid and bring the result to CPU.
        nki_out = nki.jit(nki_load_and_transform_block_tables)[1, 1](
            block_tables.to(device=device),
            num_tiles,
            num_blocks_per_tile,
            q_head_per_kv_head,
            head_id,
            block_size_tiling_factor,
        ).cpu()
        ref_out = ref_block_tables_transform(
            block_tables,
            num_tiles,
            num_blocks_per_tile,
            q_head_per_kv_head,
            head_id,
            block_size_tiling_factor,
        )
        assert (nki_out.shape == ref_out.shape
                ), f"{nki_out.shape=} != {ref_out.shape=}"
        assert torch.all(nki_out == ref_out)
tests/neuron/1_core/test_cache.py
deleted
100644 → 0
View file @
848562bd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
import
torch
from
vllm.attention.ops.nki_flash_attn
import
reshape_and_cache
@pytest.mark.parametrize(
    "num_tokens, n_kv_head, d_head, num_blocks, block_size",
    [
        # Small model configuration (e.g., GPT-2 small)
        (32, 12, 64, 4, 128),  # Typical sequence processing
        (1, 12, 64, 4, 128),  # Single token update
        (128, 12, 64, 4, 128),  # Longer sequence

        # Medium model configuration (e.g., GPT-2 medium)
        (64, 16, 96, 8, 256),  # Standard batch
        (256, 16, 96, 8, 256),  # Large batch

        # Large model configuration (e.g., GPT-3 style)
        (48, 32, 128, 16, 512),  # Typical processing window
        (512, 32, 128, 16, 512),  # Full context window

        # Edge cases and stress tests
        (1024, 8, 32, 32, 32),  # Many tokens, small heads
        (16, 64, 256, 4, 64),  # Few tokens, many heads
        (2048, 24, 128, 64, 128),  # Large scale test

        # Minimal configurations for debugging
        (4, 2, 16, 2, 16),  # Tiny test case
        (1, 1, 8, 1, 8),  # Minimal possible
    ])
def test_reshape_and_cache(num_tokens, n_kv_head, d_head, num_blocks,
                           block_size):
    """Check ``reshape_and_cache`` on XLA against a per-token CPU reference.

    Random keys/values are scattered into zero-initialised caches at unique
    random slots, once with a Python loop on CPU and once with
    ``reshape_and_cache`` on the XLA device; both caches must match.
    """
    # Set random seed for reproducibility
    torch.manual_seed(42)

    # Create CPU tensors for reference implementation
    key_cpu = torch.randn(num_tokens, n_kv_head, d_head) / torch.sqrt(
        torch.tensor(d_head))
    value_cpu = torch.randn(num_tokens, n_kv_head, d_head) / torch.sqrt(
        torch.tensor(d_head))

    key_cache_cpu = torch.zeros(num_blocks, n_kv_head, block_size, d_head)
    value_cache_cpu = torch.zeros(num_blocks, n_kv_head, block_size, d_head)
    # randperm yields distinct slots, so no two tokens share a cache entry.
    slot_mapping_cpu = torch.randperm(num_blocks * block_size)[:num_tokens]

    # Run reference implementation on CPU
    # A slot index splits into (block index, offset within the block).
    block_indices = torch.div(slot_mapping_cpu,
                              block_size,
                              rounding_mode="floor")
    block_offsets = slot_mapping_cpu % block_size

    for i in range(num_tokens):
        block_idx = block_indices[i]
        block_offset = block_offsets[i]
        key_cache_cpu[block_idx, :, block_offset, :] = key_cpu[i]
        value_cache_cpu[block_idx, :, block_offset, :] = value_cpu[i]

    # Create XLA device tensors
    device = torch.device('xla')
    key = key_cpu.to(device)
    value = value_cpu.to(device)
    key_cache = torch.zeros_like(key_cache_cpu, device=device)
    value_cache = torch.zeros_like(value_cache_cpu, device=device)
    slot_mapping = slot_mapping_cpu.to(device)
    # Key and value caches are stacked along a new leading dimension.
    kv_cache = torch.stack([key_cache, value_cache])

    # Run vectorized implementation on XLA device
    reshape_and_cache(key, value, kv_cache, slot_mapping)
    key_cache, value_cache = torch.unbind(kv_cache, dim=0)

    # Move results back to CPU for comparison
    key_cache_result = key_cache.cpu()
    value_cache_result = value_cache.cpu()

    # Assert results match
    torch.testing.assert_close(key_cache_result,
                               key_cache_cpu,
                               rtol=1e-5,
                               atol=1e-5)
    torch.testing.assert_close(value_cache_result,
                               value_cache_cpu,
                               rtol=1e-5,
                               atol=1e-5)
tests/neuron/1_core/test_layernorm.py
deleted
100644 → 0
View file @
848562bd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
import
torch
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.platforms
import
current_platform
@pytest.mark.parametrize(
    "num_tokens,hidden_size,add_residual,dtype",
    [
        (7, 8, False, torch.half),
        (83, 768, False, torch.half),
        (83, 768, True, torch.half),
        (83, 768, True, torch.bfloat16),
        (83, 768, True, torch.float32),
    ],
)
@torch.inference_mode()
def test_rms_norm(
    num_tokens: int,
    hidden_size: int,
    add_residual: bool,
    dtype: torch.dtype,
) -> None:
    """Compare RMSNorm run on an XLA device with its native CPU forward.

    The layer weights and inputs are generated on CPU with a fixed seed, the
    reference is computed with ``forward_native`` on CPU, and the layer is
    then moved to the XLA device and called on the XLA input.
    """
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    current_platform.seed_everything(0)
    torch.set_default_device("cpu")

    layer = RMSNorm(hidden_size).to(dtype=dtype)
    layer.weight.data.normal_(mean=1.0, std=0.1)

    scale = 1 / (2 * hidden_size)
    x = torch.randn(num_tokens, hidden_size, dtype=dtype).to(device=device)
    x *= scale
    if add_residual:
        residual = torch.randn_like(x) * scale
        residual_cpu = residual.cpu()
    else:
        residual = None
        residual_cpu = None

    # Reference result: native implementation on CPU copies of the inputs.
    ref_out = layer.to(device="cpu").forward_native(x.cpu(), residual_cpu)

    assert x.is_xla, "input tensor under testing is expected to be XLA tensor."
    out = layer.to(device=device)(x, residual)

    # NOTE(woosuk): LayerNorm operators (including RMS) typically have larger
    # numerical errors than other operators because they involve reductions.
    # Therefore, we use a larger tolerance.
    tolerance = {"atol": 1e-2, "rtol": 1e-2}
    if not add_residual:
        assert out.is_xla, "output tensor is expected to be XLA tensor"
        torch.testing.assert_close(out.cpu(), ref_out, **tolerance)
        return

    # With a residual the layer returns a pair; compare both elements.
    assert out[0].is_xla, "output tensor is expected to be XLA tensor"
    torch.testing.assert_close(out[0].cpu(), ref_out[0], **tolerance)
    torch.testing.assert_close(out[1].cpu(), ref_out[1], **tolerance)
tests/neuron/1_core/test_logits_processor.py
deleted
100644 → 0
View file @
848562bd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
random
from
unittest.mock
import
patch
import
pytest
import
torch
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.utils
import
is_pin_memory_available
class MockLogitsProcessor(LogitsProcessor):
    """LogitsProcessor whose logits computation is replaced by a fixed tensor."""

    def __init__(self, vocab_size: int, scale: float,
                 fake_logits: torch.Tensor):
        super().__init__(vocab_size=vocab_size, scale=scale)
        # Keep a private copy so later changes to the caller's tensor do not
        # leak into this processor.
        self.fake_logits = fake_logits.clone()

    def forward(self, *args, **kwargs):
        """Run the parent ``forward`` with pruning and logits lookup stubbed."""

        def _skip_pruning(x, y):
            # Stand-in for _prune_hidden_states: return the first argument.
            return x

        def _canned_logits(*_args, **_kwargs):
            # Stand-in for LogitsProcessor._get_logits: ignore all inputs.
            return self.fake_logits

        with patch(
                "vllm.model_executor.layers.logits_processor._prune_hidden_states",
                _skip_pruning):
            with patch(
                    "vllm.model_executor.layers.logits_processor.LogitsProcessor._get_logits",
                    _canned_logits):
                return super().forward(*args, **kwargs)
def _prepare_test(
    batch_size: int,
) -> tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor]:
    """Create the inputs shared by the logits-processor tests.

    Returns:
        A tuple ``(input_tensor, fake_logits, logits_processor)``:
        a random float16 tensor of shape ``(batch_size, 1024)``, a tensor of
        shape ``(batch_size, 32000)`` filled with ``1e-2`` in the same dtype,
        and a ``MockLogitsProcessor`` built from those logits with scale 0.5.
    """
    vocab_size = 32000
    hidden_shape = (batch_size, 1024)
    input_tensor = torch.rand(hidden_shape, dtype=torch.float16)
    fake_logits = torch.full(
        (batch_size, vocab_size),
        1e-2,
        dtype=input_tensor.dtype,
    )
    logits_processor = MockLogitsProcessor(vocab_size, 0.5, fake_logits)
    return input_tensor, fake_logits, logits_processor
# Seeds used to parametrize the logits-processor test (0 through 7).
RANDOM_SEEDS = [*range(8)]
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_logits_processors(seed: int):
    """Run a per-request logits processor through the mocked LogitsProcessor.

    A random batch of prompt sequence groups is built, each carrying the
    ``pick_ith`` processor, and column 1 of the output is compared against
    the scaled fake logits.
    """
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    set_random_seed(seed)
    torch.set_default_device("cpu")
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, logits_processor = _prepare_test(batch_size)

    # This sample logits processor gives infinite score to the i-th token,
    # where i is the length of the input sequence.
    # We therefore expect the output token sequence to be [0, 1, 2, ...]
    def pick_ith(token_ids, logits):
        logits[len(token_ids)] = float("inf")
        return logits

    seq_group_metadata_list = []
    seq_lens = []
    for i in range(batch_size):
        group = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=True,
            seq_data={0: SequenceData.from_seqs([1, 2, 3])},
            sampling_params=SamplingParams(
                temperature=0,
                logits_processors=[pick_ith],
            ),
            block_tables={0: [1]},
        )
        seq_group_metadata_list.append(group)
        seq_lens.append(group.seq_data[0].get_len())

    sampling_metadata = SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens,
        query_lens=seq_lens,
        device=device,
        pin_memory=is_pin_memory_available(),
    )
    logits_processor_output = logits_processor(
        lm_head=None,
        hidden_states=input_tensor,
        sampling_metadata=sampling_metadata,
    )

    # Apply the processor's scale to the expected logits before comparing.
    fake_logits *= logits_processor.scale
    torch.testing.assert_close(
        logits_processor_output[:, 1],
        fake_logits[:, 1],
        rtol=1e-4,
        atol=0.0,
    )
tests/neuron/1_core/test_neuron_model_runner.py
deleted
100644 → 0
View file @
848562bd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
os
from
unittest.mock
import
MagicMock
from
vllm.config
import
VllmConfig
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.platforms
import
current_platform
from
vllm.platforms.neuron
import
NeuronFramework
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
SequenceData
,
SequenceGroupMetadata
from
vllm.worker.neuron_model_runner
import
NeuronModelRunner
# Set the Neuron framework environment variable to transformers-neuronx at
# import time, before any test in this module runs.
_TNX_FRAMEWORK = NeuronFramework.TRANSFORMERS_NEURONX.value
os.environ['VLLM_NEURON_FRAMEWORK'] = _TNX_FRAMEWORK
def _create_neuron_model_runner(model: str, *args,
                                **kwargs) -> NeuronModelRunner:
    """Build a NeuronModelRunner for ``model``.

    Extra positional and keyword arguments are forwarded to ``EngineArgs``.
    Only the model, parallel, scheduler and device configs of the resulting
    engine config are copied into the ``VllmConfig`` given to the runner.
    """
    engine_config = EngineArgs(model, *args, **kwargs).create_engine_config()
    vllm_config = VllmConfig(
        model_config=engine_config.model_config,
        parallel_config=engine_config.parallel_config,
        scheduler_config=engine_config.scheduler_config,
        device_config=engine_config.device_config,
    )
    return NeuronModelRunner(vllm_config=vllm_config)
def test_update_neuron_sampling_params_not_full_batch():
    """Check sampling-parameter placement when the batch is not full.

    One sequence is scheduled with ``max_num_seqs=2``; the remaining slot is
    expected to hold the default sampling parameters.
    """
    os.environ["NEURON_ON_DEVICE_SAMPLING_DISABLED"] = "0"
    model_runner = _create_neuron_model_runner(
        "facebook/opt-125m",
        seed=0,
        dtype="float16",
        max_num_seqs=2,
    )
    assert not model_runner._on_device_sampling_disabled

    # Test sampling param updating only when TNx is framework
    # NxDI handles sampling parameter updating inside model
    if current_platform.use_transformers_neuronx():
        model_mock = MagicMock()
        model_runner.model = model_mock

        first_group = SequenceGroupMetadata(
            request_id="test_0",
            is_prompt=True,
            seq_data={0: SequenceData.from_seqs([1, 2, 3])},
            sampling_params=SamplingParams(temperature=0.5,
                                           top_k=1,
                                           top_p=0.5),
            block_tables={0: [1]},
        )
        seq_group_metadata_list = [first_group]

        model_runner.prepare_model_input(seq_group_metadata_list)

        # Index neuron sampling parameters based on block_tables indices.
        # The first block_id of the sequence 0 is 1, so its parameters are
        # placed at index 1. So the sampling parameters will be:
        # Index 0: default sampling parameters
        # Index 1: sequence 0's sampling parameters.
        neuron_sampling_params = (
            model_runner.model_config.neuron_sampling_params)
        default_top_k = model_runner._MAX_NEURON_SAMPLING_TOP_K
        assert neuron_sampling_params.temperature == [1.0, 0.5]
        assert neuron_sampling_params.top_k == [default_top_k, 1]
        assert neuron_sampling_params.top_p == [1.0, 0.5]
        model_mock.model.update_generation_config.assert_called_once_with(
            neuron_sampling_params)
def test_update_neuron_sampling_params_full_batch():
    """Check sampling-parameter placement when both batch slots are used.

    Two sequences are scheduled with ``max_num_seqs=2`` and swapped block
    ids, so their parameters are expected in swapped order.
    """
    os.environ["NEURON_ON_DEVICE_SAMPLING_DISABLED"] = "0"
    model_runner = _create_neuron_model_runner(
        "facebook/opt-125m",
        seed=0,
        dtype="float16",
        max_num_seqs=2,
    )
    assert not model_runner._on_device_sampling_disabled

    # Test sampling param updating only when TNx is framework
    # NxDI handles sampling parameter updating inside model
    if current_platform.use_transformers_neuronx():
        model_mock = MagicMock()
        model_runner.model = model_mock

        first_group = SequenceGroupMetadata(
            request_id="test_0",
            is_prompt=True,
            seq_data={0: SequenceData.from_seqs([1, 2, 3])},
            sampling_params=SamplingParams(temperature=0.5,
                                           top_k=1,
                                           top_p=0.5),
            block_tables={0: [1]},
        )
        second_group = SequenceGroupMetadata(
            request_id="test_0",
            is_prompt=True,
            seq_data={1: SequenceData.from_seqs([4, 5, 6])},
            sampling_params=SamplingParams(temperature=0.2,
                                           top_k=2,
                                           top_p=0.2),
            block_tables={1: [0]},
        )
        seq_group_metadata_list = [first_group, second_group]

        model_runner.prepare_model_input(seq_group_metadata_list)

        # Index neuron sampling parameters based on block_tables indices.
        # The first block_id of the sequence 0 is 1, so its parameters are
        # placed at index 1. So the sampling parameters will be:
        # Index 0: sequence 1's sampling parameters
        # Index 1: sequence 0's sampling parameters.
        neuron_sampling_params = (
            model_runner.model_config.neuron_sampling_params)
        assert neuron_sampling_params.temperature == [0.2, 0.5]
        assert neuron_sampling_params.top_k == [2, 1]
        assert neuron_sampling_params.top_p == [0.2, 0.5]
        model_mock.model.update_generation_config.assert_called_once_with(
            neuron_sampling_params)
tests/neuron/1_core/test_neuron_quant.py
deleted
100644 → 0
View file @
848562bd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
vllm.model_executor.layers.quantization.neuron_quant
import
(
NeuronQuantConfig
)
def test_get_supported_act_dtypes():
    """Every expected dtype name must be reported by NeuronQuantConfig."""
    supported_act_dtypes = NeuronQuantConfig().get_supported_act_dtypes()
    expected = ("any_dtype1", "any_dtype2")
    missing = [name for name in expected if name not in supported_act_dtypes]
    assert not missing
tests/neuron/1_core/test_prefix_prefill.py
deleted
100644 → 0
View file @
848562bd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Optional
import
pytest
import
torch
import
torch.nn.functional
as
F
from
vllm.utils
import
cdiv
class BlockDiagonalCausalFromBottomRightMask:
    """Builds boolean attention masks for a batch of packed sequences.

    Queries of all sequences are concatenated along the row axis and keys
    along the column axis; each sequence only attends inside its own
    row/column block.
    """

    @staticmethod
    def _from_seqlens(query_lens, seq_lens, block_size=None):
        """Build one ``(sum(query_lens), n_keys)`` mask.

        Args:
            query_lens: list of per-sequence query lengths.
            seq_lens: list of per-sequence total lengths (context + query).
            block_size: when ``None`` the key axis has ``seq_lens[i]`` columns
                per sequence and the mask is causal, aligned to the bottom
                right of each block. Otherwise the key axis covers only the
                context (``seq_len - query_len``) rounded up to a multiple of
                ``block_size``, and every query row of a sequence sees its
                first ``context_len`` columns.

        Returns:
            A boolean tensor when at least one sequence is given; with no
            sequences the untouched float zero tensor is returned.
        """
        contexted = block_size is None
        context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens)
        n_queries = sum(query_lens)
        num_seqs = len(query_lens)
        if contexted:
            key_lens_blockaligned = seq_lens
        else:
            # Round each context length up to a whole number of blocks.
            n_blocks_per_seq = (context_lens + block_size - 1) // block_size
            offset_per_seq = n_blocks_per_seq * block_size
            key_lens_blockaligned = offset_per_seq[:num_seqs].tolist()
        n_keys = sum(key_lens_blockaligned)

        # a[i, j] == i (query row index), b[i, j] == j (key column index).
        a = torch.arange(n_queries).reshape(n_queries, 1).expand(
            n_queries, n_keys)
        b = torch.arange(n_keys).reshape(1, n_keys).expand(n_queries, n_keys)
        q_cumsum = torch.tensor([0] + query_lens).cumsum(dim=0)
        k_cumsum = torch.tensor([0] + key_lens_blockaligned).cumsum(dim=0)

        prior_mask = torch.zeros(n_queries, n_keys)
        for seq_id in range(num_seqs):
            ri = q_cumsum[seq_id]  # first query row of this sequence
            ci = k_cumsum[seq_id]  # first key column of this sequence
            nr = query_lens[seq_id]

            if contexted:
                nc = seq_lens[seq_id]
                # Shift so the last query row lines up with the last key.
                a_offset = ci + nc - ri - nr
                new_mask = (a + a_offset) >= b
            else:
                nc = context_lens[seq_id]
                # Only the first ``nc`` columns of the padded block are valid.
                a_offset = ci + nc - 1
                new_mask = a_offset >= b

            # Restrict the mask to this sequence's rows and to columns at or
            # after its first key.
            left_mask = b >= ci
            top_mask = a >= ri
            bottom_mask = a < (ri + nr)
            new_mask = torch.logical_and(
                torch.logical_and(torch.logical_and(new_mask, left_mask),
                                  top_mask),
                bottom_mask,
            )
            prior_mask = torch.logical_or(prior_mask, new_mask)
        return prior_mask

    @staticmethod
    def from_seqlens(query_lens, seq_lens, block_size=None):
        """Return ``(prior_mask, active_mask)`` for the batch.

        Without ``block_size`` a single mask over the full sequences is
        returned and ``active_mask`` is ``None``. With ``block_size`` the
        prior mask covers the block-aligned context keys and the active mask
        is the causal mask of the queries over themselves.
        """
        build = BlockDiagonalCausalFromBottomRightMask._from_seqlens
        if block_size is None:
            return build(query_lens, seq_lens), None
        prior_mask = build(query_lens, seq_lens, block_size)
        active_mask = build(query_lens, query_lens)
        return prior_mask, active_mask
def ref_softmax(x: torch.Tensor,
                dim: int,
                mixed_precision=False,
                return_max_reduce=False):
    """Reference softmax along ``dim`` with max subtraction.

    Args:
        x: input tensor.
        dim: dimension to normalize over.
        mixed_precision: when true, the sum of exponentials is accumulated in
            float32 and cast back to ``x.dtype``.
        return_max_reduce: when true, also return the reductions.

    Returns:
        The softmax of ``x``; or, when ``return_max_reduce`` is true, the
        tuple ``(softmax, max_value, 1 / sum_value)`` where both reductions
        keep ``dim`` with size 1.
    """
    max_value = torch.amax(x, dim=dim, keepdim=True)
    exp = torch.exp(x - max_value)
    if mixed_precision:
        # torch tensors have no ``astype``; dtype casts go through ``to``.
        sum_value = torch.sum(exp.to(torch.float32), dim=dim,
                              keepdim=True).to(x.dtype)
    else:
        sum_value = torch.sum(exp, dim=dim, keepdim=True)
    if return_max_reduce:
        return exp / sum_value, max_value, torch.reciprocal(sum_value)
    return exp / sum_value
def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,
    return_max_reduce: Optional[bool] = False,
) -> tuple[torch.Tensor, ...]:
    """Reference scaled dot-product attention with an optional additive mask.

    Args:
        query: tensor indexed as ``(q, h, d)``.
        key: tensor indexed as ``(k, h, d)``.
        value: tensor indexed as ``(k, h, d)``.
        scale: multiplier applied to the query-key scores.
        attn_mask: optional tensor added to the float32 scores; when omitted
            the scores are used unmasked.
        return_max_reduce: when true, intermediate tensors are returned too.

    Returns:
        ``(out,)`` with ``out`` indexed as ``(q, h, d)``; or, when
        ``return_max_reduce`` is true, ``(out, cached_max,
        cached_sum_reciprocal, norm_score, masked_score, scaled_qk)``.
    """
    scaled_qk = scale * torch.einsum("qhd,khd->hqk", query, key).float()
    # Without a mask the scores pass through unchanged; previously this left
    # ``masked_score`` unbound and the call failed.
    if attn_mask is None:
        masked_score = scaled_qk
    else:
        masked_score = scaled_qk + attn_mask.float()

    if return_max_reduce:
        norm_score, cached_max, cached_sum_reciprocal = ref_softmax(
            masked_score, dim=-1, return_max_reduce=True)
    else:
        norm_score = ref_softmax(masked_score, dim=-1)

    # Cast the float32 probabilities back to the value dtype for the product.
    out = torch.einsum("hqk,khd->qhd", norm_score.to(value.dtype), value)
    if return_max_reduce:
        return (
            out,
            cached_max,
            cached_sum_reciprocal,
            norm_score,
            masked_score,
            scaled_qk,
        )
    return (out, )
def ref_context_attention(
    query,
    key,
    value,
    query_lens,
    seq_lens,
    head_size,
    num_queries_per_kv,
    return_max_reduce=False,
):
    """Reference attention over packed sequences with prior context.

    Key/value heads are repeated ``num_queries_per_kv`` times along dim 1
    when that factor exceeds one, a block-diagonal causal mask is built from
    the sequence lengths, and ``ref_masked_attention`` does the computation.

    Returns:
        The attention output with an extra dimension inserted at index 1; or,
        when ``return_max_reduce`` is true, a 6-tuple of that output followed
        by the five intermediate tensors from ``ref_masked_attention``.
    """
    scale = float(1.0 / (head_size**0.5))

    # Handle MQA and GQA
    if num_queries_per_kv > 1:
        key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
        value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)

    allowed, _ = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
        query_lens, seq_lens)

    # convert binary mask to -inf values
    blocked = torch.logical_not(allowed)
    attn_mask = blocked.float() * -30000

    output, *debug_tensors = ref_masked_attention(
        query,
        key,
        value,
        scale,
        attn_mask,
        return_max_reduce=return_max_reduce,
    )
    output = output.unsqueeze(1)

    if not return_max_reduce:
        return output

    # The slot named ``lse`` receives the normalized scores returned by
    # ref_masked_attention.
    (cached_max, cached_sum_reciprocal, lse, masked_score,
     scaled_qk) = debug_tensors
    return (
        output,
        cached_max,
        cached_sum_reciprocal,
        lse,
        masked_score,
        scaled_qk,
    )
def sample_inputs(
    prefill_batch_size,
    decode_batch_size,
    min_query_len,
    max_query_len,
    min_ctx_len,
    max_ctx_len,
    block_size,
    num_heads,
    num_kv_heads,
    head_size,
    dtype,
):
    """Generate random attention inputs plus a paged KV cache.

    Prefill requests get a random query length in
    ``[min_query_len, max_query_len]``; decode requests always have query
    length 1. Every request gets a random context length in
    ``[min_ctx_len, max_ctx_len]``.

    Returns:
        A tuple ``(query, k, v, kv_cache, block_table, key, value,
        query_lens, seq_lens)`` where
        - ``query`` is ``(sum(query_lens), num_heads, head_size)``,
        - ``k`` / ``v`` hold the new (non-context) tokens,
          ``(sum(query_lens), num_kv_heads, head_size)``,
        - ``kv_cache`` is the stacked key and value caches,
          ``(2, cache_size, block_size, num_kv_heads, head_size)``, filled
          with each request's context tokens,
        - ``block_table`` is ``(batch_size, max_block_per_request)``,
        - ``key`` / ``value`` hold all tokens of all sequences,
        - ``query_lens`` / ``seq_lens`` are Python lists.
    """
    batch_size = prefill_batch_size + decode_batch_size
    max_model_len = (max_query_len + max_ctx_len) * 4
    max_block_per_request = max_model_len // block_size
    cache_size = (batch_size * max_block_per_request) + 2

    prefill_ctx_lens = torch.randint(min_ctx_len,
                                     max_ctx_len + 1, (prefill_batch_size, ),
                                     dtype=torch.long).tolist()
    decode_ctx_lens = torch.randint(min_ctx_len,
                                    max_ctx_len + 1, (decode_batch_size, ),
                                    dtype=torch.long).tolist()
    ctx_lens = prefill_ctx_lens + decode_ctx_lens
    query_lens = torch.randint(
        min_query_len,
        max_query_len + 1,
        (prefill_batch_size, ),
        dtype=torch.long,
    ).tolist() + [1 for _ in range(decode_batch_size)]
    seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]

    num_tokens = sum(query_lens)
    query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
    query.uniform_(-1, 1)

    kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
    kv.uniform_(-1, 1)
    key, value = kv.unbind(dim=1)

    k_cache = torch.zeros(cache_size,
                          block_size,
                          num_kv_heads,
                          head_size,
                          dtype=dtype)
    v_cache = torch.zeros(cache_size,
                          block_size,
                          num_kv_heads,
                          head_size,
                          dtype=dtype)
    k = torch.zeros(num_tokens, num_kv_heads, head_size, dtype=dtype)
    v = torch.zeros(num_tokens, num_kv_heads, head_size, dtype=dtype)

    # Assign each request a random, non-overlapping set of cache blocks.
    values = torch.arange(0, cache_size, dtype=torch.long)
    values = values[torch.randperm(cache_size)]
    block_table = values[:batch_size * max_block_per_request].view(
        batch_size, max_block_per_request)

    # Start offsets of each request in the packed query / sequence layouts.
    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1],
                                            dtype=torch.long),
                               dim=0).tolist()
    b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1],
                                                dtype=torch.long),
                                   dim=0).tolist()

    # Flat slot views sharing storage with the caches.
    k_slots = k_cache.view(-1, num_kv_heads, head_size)
    v_slots = v_cache.view(-1, num_kv_heads, head_size)

    # copy kv to cache
    for i in range(batch_size):
        q_len = query_lens[i]
        ctx_len = ctx_lens[i]
        seq_start = b_seq_start_loc[i]

        # New tokens sit after the context inside each packed sequence.
        new_start = seq_start + ctx_len
        q_start = b_start_loc[i]
        k[q_start:q_start + q_len].copy_(key[new_start:new_start + q_len])
        v[q_start:q_start + q_len].copy_(value[new_start:new_start + q_len])

        # Context tokens go into the cache one block at a time; the last
        # block may be partially filled.
        for block_id, cur_ctx in enumerate(range(0, ctx_len, block_size)):
            start_loc = seq_start + cur_ctx
            end_loc = seq_start + min(cur_ctx + block_size, ctx_len)
            start_slot = int(block_table[i, block_id]) * block_size
            end_slot = start_slot + end_loc - start_loc
            k_slots[start_slot:end_slot].copy_(key[start_loc:end_loc])
            v_slots[start_slot:end_slot].copy_(value[start_loc:end_loc])

    kv_cache = torch.stack([k_cache, v_cache])

    return (
        query,
        k,
        v,
        kv_cache,
        block_table,
        key,
        value,
        query_lens,
        seq_lens,
    )
def get_active_block_tables(block_tables, query_lens, seq_lens, block_size,
                            num_blocks):
    """Flatten the context blocks of every sequence into one padded vector.

    Args:
        block_tables: 2-D tensor; row ``i`` lists the block ids of sequence
            ``i``.
        query_lens: tensor of per-sequence query lengths.
        seq_lens: tensor of per-sequence total lengths.
        block_size: number of tokens per block.
        num_blocks: length of the returned vector.

    Returns:
        A 1-D int32 tensor holding, sequence by sequence, the block ids that
        cover each context (``seq_len - query_len`` tokens), right-padded
        with zeros to ``num_blocks`` entries.
    """
    context_lens = seq_lens - query_lens
    # Ceil-divide the context lengths to get the block count per sequence.
    blocks_per_seq = ((context_lens + block_size - 1) // block_size).tolist()

    active_blocks: list[int] = []
    for seq_id, n_blocks in enumerate(blocks_per_seq):
        # ``extend`` grows the list in place instead of rebuilding it.
        active_blocks.extend(block_tables[seq_id, :n_blocks].tolist())

    return F.pad(
        torch.tensor(active_blocks, dtype=torch.int32),
        (0, num_blocks - len(active_blocks)),
        "constant",
        0,
    )
@pytest.mark.parametrize(
    "prefill_batch_size,decode_batch_size,block_size,large_tile_size,num_heads,num_queries_per_kv,head_size,mixed_precision",
    [
        # Test minimal configurations (small block size)
        (1, 199, 1, 512, 4, 2, 8, False),  # minimal block size, small dimensions
        (1, 199, 1, 512, 4, 2, 8, True),  # same with mixed precision
        # Test common/medium configurations
        (4, 12, 32, 2048, 32, 8, 64, False),  # common case, larger heads
        (4, 12, 32, 2048, 16, 4, 32, True),  # medium size, mixed precision, grouped-query attention (GQA)
        # Test large configurations
        (4, 12, 256, 8192, 8, 1, 128, False),  # large blocks, large head size
        (4, 12, 256, 8192, 64, 8, 64, True),  # large blocks, many heads
        # Test asymmetric configurations
        (2, 24, 64, 4096, 12, 4, 96, False),  # varied batch sizes
        (8, 8, 128, 2048, 24, 2, 48, True),  # balanced batches
        # Test edge cases
        (1, 128, 16, 1024, 4, 2, 16, False),  # large decode batch
        (16, 4, 8, 1024, 4, 2, 128, True),  # large prefill batch
        (4, 12, 32, 2048, 16, 1, 32, True),  # multi-head attention (MHA)
        (4, 12, 32, 2048, 16, 16, 32, True),  # multi-query attention (MQA)
    ],
)
@torch.inference_mode()
def test_contexted_kv_attention(
    monkeypatch: pytest.MonkeyPatch,
    prefill_batch_size: int,
    decode_batch_size: int,
    num_heads: int,
    num_queries_per_kv: int,
    head_size: int,
    block_size: int,
    large_tile_size,
    mixed_precision: bool,
) -> None:
    """Check ``flash_attn_varlen_nkifunc`` against the CPU reference attention.

    Random inputs and a paged KV cache come from ``sample_inputs``; the
    reference output comes from ``ref_context_attention``. Inputs are padded,
    permuted and moved to the XLA device before the kernel call, and only the
    rows of real (unpadded) tokens are compared.
    """
    import torch_xla.core.xla_model as xm

    from vllm.attention.ops.nki_flash_attn import (flash_attn_varlen_nkifunc,
                                                   reorder_context_mask)

    assert large_tile_size % block_size == 0

    device = xm.xla_device()

    compiler_flags = ["-O1", "--retry_failed_compilation"]
    compiler_flags_str = " ".join(compiler_flags)

    # NOTE(review): the whole body is kept inside the monkeypatch context so
    # NEURON_CC_FLAGS is set when the kernel is invoked - confirm against the
    # original indentation.
    with monkeypatch.context() as m:
        m.setenv("NEURON_CC_FLAGS", compiler_flags_str)

        torch.manual_seed(0)
        torch.set_printoptions(sci_mode=False)
        torch.set_default_device("cpu")
        dtype = torch.float32

        min_ctx_len = 32
        max_ctx_len = 1024
        min_query_len = 16
        max_query_len = 512
        num_kv_heads = num_heads // num_queries_per_kv
        (
            query,
            k_active,
            v_active,
            kv_cache,
            block_table,
            key,
            value,
            query_lens,
            seq_lens,
        ) = sample_inputs(
            prefill_batch_size=prefill_batch_size,
            decode_batch_size=decode_batch_size,
            min_query_len=min_query_len,
            max_query_len=max_query_len,
            min_ctx_len=min_ctx_len,
            max_ctx_len=max_ctx_len,
            block_size=block_size,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_size=head_size,
            dtype=dtype,
        )

        output_ref = ref_context_attention(
            query,
            key,
            value,
            query_lens,
            seq_lens,
            head_size,
            num_queries_per_kv,
            return_max_reduce=False,
        )

        # build neuron program
        B_P_SIZE = 128
        assert (large_tile_size >= B_P_SIZE
                ), f"Expect {large_tile_size=} to be larger than {B_P_SIZE=}"

        def pad_to_multiple(amount, multiple):
            # Round ``amount`` up to the next multiple of ``multiple``.
            return cdiv(amount, multiple) * multiple

        def pad_to_next_power_of_2(amount):
            # Smallest power of two that is >= ``amount``.
            assert amount > 0
            return 2**int(amount - 1).bit_length()

        # calculate input shapes
        num_actual_tokens = sum(query_lens)
        max_num_queries = pad_to_next_power_of_2(num_actual_tokens)
        context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens)
        num_active_blocks = cdiv(context_lens, block_size).sum().item()
        blocks_per_tile = large_tile_size // block_size
        num_active_blocks = pad_to_multiple(num_active_blocks,
                                            blocks_per_tile)
        context_kv_len = num_active_blocks * block_size
        assert (context_kv_len % large_tile_size == 0
                ), f"invalid context_kv_len={context_kv_len}"

        # pad QKV tensors
        pad_dims = (
            0,
            0,
            0,
            0,
            0,
            max_num_queries - query.shape[0],
        )
        query = F.pad(query, pad_dims, "constant", 0)
        k = F.pad(k_active, pad_dims, "constant", 0)
        v = F.pad(v_active, pad_dims, "constant", 0)

        # permute QKV tensors
        # query: (1, n_heads, d, seq_q)
        # key:   (1, n_kv_heads, d, seq_k)
        # value: (1, n_kv_heads, seq_v, d)
        query = query.unsqueeze(0).permute(0, 2, 3, 1).contiguous()
        k = k.unsqueeze(0).permute(0, 2, 3, 1).contiguous()
        v = v.unsqueeze(0).permute(0, 2, 1, 3).contiguous()
        kv_cache = kv_cache.permute(0, 1, 3, 2, 4).contiguous()

        # transform block table
        active_block_table = get_active_block_tables(
            block_table.cpu(),
            torch.tensor(query_lens).cpu(),
            torch.tensor(seq_lens).cpu(),
            block_size,
            num_active_blocks,
        )

        # Build attention masks
        prior_mask, active_mask = (
            BlockDiagonalCausalFromBottomRightMask.from_seqlens(
                query_lens, seq_lens, block_size=block_size))
        prior_pad = (
            0,
            context_kv_len - prior_mask.shape[1],
            0,
            max_num_queries - prior_mask.shape[0],
        )
        prior_mask_padded = F.pad(prior_mask, prior_pad, "constant",
                                  0).bool()
        active_pad = (
            0,
            max_num_queries - active_mask.shape[1],
            0,
            max_num_queries - active_mask.shape[0],
        )
        active_mask_padded = F.pad(active_mask, active_pad, "constant",
                                   0).bool()
        attn_mask = torch.concat([prior_mask_padded, active_mask_padded],
                                 dim=1)

        attn_mask = reorder_context_mask(attn_mask, large_tile_size,
                                         block_size)

        host_tensors = (query, k, v, kv_cache, active_block_table, attn_mask)
        input_args = tuple(t.to(device=device) for t in host_tensors)
        input_kwargs = {
            "n_kv_head": num_kv_heads,
            "head_size": head_size,
            "mixed_precision": mixed_precision,
            "LARGE_TILE_SZ": large_tile_size,
        }

        output_nki = flash_attn_varlen_nkifunc(*input_args, **input_kwargs)

        # - o: shape (bs, n_heads, seq_q, d) -> (bs, seq_q, n_heads, d)
        output_nki = output_nki.cpu().permute(0, 2, 1, 3)
        output_nki = output_nki[0, :num_actual_tokens, :, :]

        # Pad dim 0 of the reference the same way, then keep the real rows.
        ref_pad = (0, 0, 0, 0, 0, 0, 0,
                   max_num_queries - output_ref.shape[0])
        output_ref_padded = F.pad(output_ref, ref_pad, "constant", 0)
        output_ref = output_ref_padded.transpose(
            0, 1)[0, :num_actual_tokens, :, :]

        torch.testing.assert_close(output_nki, output_ref, atol=1e-2, rtol=0)
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment