Unverified Commit 4172235a authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[V0 deprecation] Deprecate V0 Neuron backend (#21159)


Signed-off-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
parent 848562bd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for miscellaneous utilities
"""
import pytest
import torch
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform
@pytest.mark.parametrize(
"max_position,is_neox_style,rotary_dim,head_size,seq_len,use_key", [
(16, False, 32, 32, 1024, True),
(16, False, 32, 128, 1024, True),
(16, True, 32, 32, 1024, True),
(16, True, 32, 128, 1024, True),
(16, False, 32, 128, 1024, False),
(16, True, 32, 128, 1024, False),
])
def test_rotary_embedding_opcheck(max_position, is_neox_style, rotary_dim,
head_size, seq_len, use_key):
import torch_xla.core.xla_model as xm
device = xm.xla_device()
current_platform.seed_everything(0)
torch.set_default_device("cpu")
batch_size = 1
base = 10000
num_heads = 8
rot = RotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style, torch.float32)
positions = torch.randint(0,
max_position, (batch_size, seq_len),
device="cpu")
query = torch.randn(batch_size,
seq_len,
num_heads * head_size,
dtype=torch.float32,
device="cpu")
key = torch.randn_like(query) if use_key else None
assert positions.is_cpu, \
"reference input tensor is expected to be CPU tensor."
ref_query, ref_key = rot.to(device="cpu").forward_native(
positions, query, key)
out_query, out_key = rot.to(device=device).forward_neuron(
positions.to(device=device), query.to(device=device),
key.to(device=device) if key is not None else None)
if use_key:
assert out_query.is_xla and out_key.is_xla, \
"output tensor is expected to be XLA tensor"
torch.testing.assert_close(out_key.cpu(),
ref_key,
atol=1e-2,
rtol=1e-2)
else:
assert out_key is None, "expected returned key to be None"
assert out_query.is_xla, \
"output tensor is expected to be XLA tensor"
torch.testing.assert_close(out_query.cpu(),
ref_query,
atol=1e-2,
rtol=1e-2)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from typing import Callable
from unittest.mock import patch
import pytest
import torch
import torch_xla.distributed.xla_multiprocessing as xmp
from typing_extensions import ParamSpec
from vllm.distributed.communication_op import (
tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.utils import get_distributed_init_method, get_open_port
_P = ParamSpec("_P")
def reinitialize_neuron_runtime(f: Callable[_P, None]) -> Callable[_P, None]:
"""Decorator to reinitialize the Neuron Runtime before executing a test.
This is necessary for distributed tests which need to reallocate Neuron
Cores to separate subprocesses.
"""
@functools.wraps(f)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
runtime = torch.classes.neuron.Runtime()
runtime.initialize()
runtime.unsafe_close()
f(*args, **kwargs)
runtime.initialize()
return wrapper
def all_gather_test_worker(index, tp_degree, distributed_init_method):
init_distributed_environment(tp_degree,
index,
distributed_init_method,
index,
backend="xla")
ensure_model_parallel_initialized(tp_degree, 1)
num_dimensions = 3
tensor_size = list(range(2, num_dimensions + 2))
total_size = 1
for s in tensor_size:
total_size *= s
all_gather_dimension = -1
all_tensors = [
torch.arange(total_size, dtype=torch.float32,
device="xla").reshape(tensor_size) * (r + 1)
for r in range(tp_degree)
]
expected = torch.cat(all_tensors, dim=all_gather_dimension)
t = all_tensors[index % tp_degree]
t = tensor_model_parallel_all_gather(t, all_gather_dimension)
torch.testing.assert_close(t, expected)
def all_reduce_test_worker(index, tp_degree, distributed_init_method):
init_distributed_environment(tp_degree,
index,
distributed_init_method,
index,
backend="xla")
ensure_model_parallel_initialized(tp_degree, 1)
num_elements = 8
all_tensors = [
torch.arange(num_elements, dtype=torch.float32, device="xla") * (r + 1)
for r in range(tp_degree)
]
expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
t = all_tensors[index % tp_degree]
t = tensor_model_parallel_all_reduce(t)
torch.testing.assert_close(t, expected)
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("test_target",
[all_reduce_test_worker, all_gather_test_worker])
@reinitialize_neuron_runtime
def test_neuron_multi_process_tensor_parallel(monkeypatch, tp_size,
test_target):
with patch('torch_xla._XLAC._xla_runtime_is_initialized',
return_value=False):
distributed_init_method = get_distributed_init_method(
"127.0.0.1", get_open_port())
monkeypatch.setenv("VLLM_USE_V1", "1")
monkeypatch.setenv("NEURONCORE_NUM_DEVICES", str(tp_size))
monkeypatch.setenv("NEURON_PJRT_PROCESSES_NUM_DEVICES",
','.join(['1' for _ in range(tp_size)]))
xmp.spawn(test_target, args=(tp_size, distributed_init_method))
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import os
import shutil
import tempfile
import torch
from huggingface_hub import snapshot_download
from safetensors import safe_open
from vllm import LLM, SamplingParams
def patch_eagle_draft_with_lm_head(target_model_id: str,
draft_model_id: str) -> str:
# In NxDI, draft model checkpoint must include lm_head weights from target
# model. For more details see https://awsdocs-neuron.readthedocs-hosted.com
# /en/latest/libraries/nxd-inference/developer_guides/feature-guide.html
# #eagle-checkpoint-compatibility
final_draft_dir = "/tmp/patched_eagle_draft"
with tempfile.TemporaryDirectory() as tmp_dir:
target_dir = snapshot_download(repo_id=target_model_id,
local_dir=os.path.join(
tmp_dir, "target"))
draft_dir = snapshot_download(repo_id=draft_model_id,
local_dir=os.path.join(tmp_dir, "draft"))
lm_head_key = "lm_head.weight"
index_path = os.path.join(target_dir, "model.safetensors.index.json")
with open(index_path) as f:
index = json.load(f)
shard_name = index["weight_map"][lm_head_key]
target_safetensor_path = os.path.join(target_dir, shard_name)
with safe_open(target_safetensor_path, framework="pt") as f:
target_lm_head = f.get_tensor(lm_head_key)
draft_path = os.path.join(draft_dir, "pytorch_model.bin")
draft_state_dict = torch.load(draft_path, map_location="cpu")
draft_state_dict[lm_head_key] = target_lm_head.to(torch.float16)
torch.save(draft_state_dict, draft_path)
shutil.copytree(draft_dir, final_draft_dir, dirs_exist_ok=True)
return final_draft_dir
def test_eagle():
patched_draft_path = patch_eagle_draft_with_lm_head(
target_model_id="meta-llama/Llama-2-7b-hf",
draft_model_id="yuhuili/EAGLE-llama2-chat-7B")
llm = LLM(
model="meta-llama/Llama-2-7b-hf",
speculative_config={
"model": patched_draft_path,
"num_speculative_tokens": 5,
"max_model_len": 128
},
max_num_seqs=1,
max_model_len=128,
tensor_parallel_size=2,
override_neuron_config={
"enable_eagle_speculation": True,
"enable_fused_speculation": True,
"fused_qkv": True
},
)
prompts = [
"The president of the United States is",
]
outputs = llm.generate(prompts, SamplingParams(top_k=1))
expected_output = " the head of state and head of government of " \
"the United States. The president direct"
for output in outputs:
generated_text = output.outputs[0].text
print(f"Prompt: {output.prompt!r}, Generated text: {generated_text!r}")
assert (expected_output == generated_text)
print("Neuron Eagle speculation test passed.")
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm import LLM, SamplingParams
def test_mistral():
llm = LLM(model="mistralai/Mistral-7B-v0.1",
tensor_parallel_size=2,
max_num_seqs=4,
max_model_len=128,
override_neuron_config={
"sequence_parallel_enabled": False,
"skip_warmup": True
})
# Send more prompts than the compiled batch size (4) and request
# varying generation lengths to test accuracy related to Neuron
# specific sequence id sorting.
prompts = [
"The president of the United States is",
"The capital of France is",
"What is Annapurna labs?",
"I believe the meaning of life is",
"Tell me a story about a brave knight",
"Hello, my name is Llama",
]
sampling_params = [
SamplingParams(top_k=1, max_tokens=10),
SamplingParams(top_k=1, max_tokens=20),
SamplingParams(top_k=1, max_tokens=30),
SamplingParams(top_k=1, max_tokens=40),
SamplingParams(top_k=1, max_tokens=50),
SamplingParams(top_k=1, max_tokens=60)
]
outputs = llm.generate(prompts, sampling_params)
expected_outputs = [
" the most powerful person in the world. He is",
" a city of many faces. It is a city of history, culture, art, "
"fashion, and",
"\n\nAnnapurna Labs is a semiconductor company that was founded "
"in 2013 by Amazon. The company is",
" to be happy.\n\nI believe that happiness is a choice.\n\nI "
"believe that happiness is a state of mind.\n\nI believe that "
"happiness is a journey.\n\nI believe",
" who rescued a princess from a dragon.\n\nTell me a story about"
" a princess who rescued herself from a dragon.\n\nTell me a "
"story about a princess who rescued herself from a dragon and "
"then rescued a knight from",
" and I am a 10 year old male. I am a very friendly and "
"affectionate boy who loves to be around people. I am a very "
"active boy who loves to play and run around. I am a very smart "
"boy who loves to learn new things. I am a very loyal boy"
]
for expected_output, output in zip(expected_outputs, outputs):
generated_text = output.outputs[0].text
print(f"Prompt: {output.prompt!r}, Generated text: {generated_text!r}")
assert (expected_output == generated_text)
print("Neuron Mistral test passed.")
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from huggingface_hub import snapshot_download
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
def test_llama_single_lora():
sql_lora_files = snapshot_download(
repo_id="yard1/llama-2-7b-sql-lora-test")
llm = LLM(model="meta-llama/Llama-2-7b-hf",
tensor_parallel_size=2,
max_num_seqs=4,
max_model_len=512,
override_neuron_config={
"sequence_parallel_enabled": False,
"skip_warmup": True,
"lora_modules": [{
"name": "lora_id_1",
"path": sql_lora_files
}]
},
enable_lora=True,
max_loras=1,
max_lora_rank=256,
device="neuron")
"""For multi-lora requests using NxDI as the backend, only the lora_name
needs to be specified. The lora_id and lora_path are supplied at the LLM
class/server initialization, after which the paths are handled by NxDI"""
lora_req_1 = LoRARequest("lora_id_1", 0, " ")
prompts = [
"The president of the United States is",
"The capital of France is",
]
outputs = llm.generate(prompts,
SamplingParams(top_k=1),
lora_request=[lora_req_1, lora_req_1])
expected_outputs = [
" the head of state and head of government of the United States. "
"The president direct",
" a city of contrasts. The city is home to the Eiffel Tower"
]
for expected_output, output in zip(expected_outputs, outputs):
generated_text = output.outputs[0].text
assert (expected_output == generated_text)
def test_llama_multiple_lora():
sql_lora_files = snapshot_download(
repo_id="yard1/llama-2-7b-sql-lora-test")
llm = LLM(model="meta-llama/Llama-2-7b-hf",
tensor_parallel_size=2,
max_num_seqs=4,
max_model_len=512,
override_neuron_config={
"sequence_parallel_enabled":
False,
"skip_warmup":
True,
"lora_modules": [{
"name": "lora_id_1",
"path": sql_lora_files
}, {
"name": "lora_id_2",
"path": sql_lora_files
}]
},
enable_lora=True,
max_loras=2,
max_lora_rank=256,
device="neuron")
"""For multi-lora requests using NxDI as the backend, only the lora_name
needs to be specified. The lora_id and lora_path are supplied at the LLM
class/server initialization, after which the paths are handled by NxDI"""
lora_req_1 = LoRARequest("lora_id_1", 0, " ")
lora_req_2 = LoRARequest("lora_id_2", 1, " ")
prompts = [
"The president of the United States is",
"The capital of France is",
]
outputs = llm.generate(prompts,
SamplingParams(top_k=1),
lora_request=[lora_req_1, lora_req_2])
expected_outputs = [
" the head of state and head of government of the United States. "
"The president direct",
" a city of contrasts. The city is home to the Eiffel Tower"
]
for expected_output, output in zip(expected_outputs, outputs):
generated_text = output.outputs[0].text
assert (expected_output == generated_text)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl
import numpy as np
import torch
from neuronxcc import nki
from neuronxcc.nki.language import par_dim
from vllm.utils import cdiv
def is_power_of_2(x):
return x > 0 and (x & (x - 1)) == 0
@nki.jit
def load_block_tables(block_tables_hbm, num_tiles, num_blocks_per_tile):
"""
Load block tables from HBM into SRAM
`block_tables_hbm` has shape `(num_tiles * num_blocks_per_tile, )`.
In case `num_tiles > B_P_SIZE`, we need further tile `num_tile` dimension.
"""
B_P_SIZE = 128
# reshape as `(num_tiles, num_blocks_per_tile)`
assert len(block_tables_hbm.shape) == 1
(num_total_blocks, ) = block_tables_hbm.shape
assert num_blocks_per_tile * num_tiles == num_total_blocks
block_tables_hbm = block_tables_hbm.reshape(
(num_tiles, num_blocks_per_tile))
block_tables_sbuf = nl.zeros(
(cdiv(num_tiles, B_P_SIZE), par_dim(B_P_SIZE), num_blocks_per_tile),
dtype=nl.int32,
)
for i in nl.affine_range(cdiv(num_tiles, B_P_SIZE)):
i_p = nl.arange(B_P_SIZE)[:, None]
i_f = nl.arange(num_blocks_per_tile)[None, :]
block_tables_sbuf[i, i_p, i_f] = nl.load(
block_tables_hbm[i_p + i * B_P_SIZE, i_f],
dtype=nl.int32,
mask=(i_p + i * B_P_SIZE < num_tiles),
)
return block_tables_sbuf
@nki.jit
def transform_block_tables_for_indirect_load(
block_tables,
block_size_tiling_factor,
num_head,
head_id,
):
"""
This function does two things:
1. calculate new `block_tables` for a `head_id` after flattening
`num_block`, `num_head`, and `block_size_tiling_factor` dimensions
2. transpose the result so that `block_table` for each tile is mapped to
SBUF Partition dimension for vectorized DMA
Tiling trick to further improve DMA performance:
Given KV cache shape `(num_block, num_head, block_size, D)`, when loading M
blocks of a given `head_id` from HBM, the load `cache[block_tables,
head_id]` has shape `(M, block_size, D)`. If M < B_P_SIZE = 128, DMA may not
fully utilize hardware parallelization. The solution is to tile `block_size`
into `(block_size_tiling_factor, tiled_block_size)` s.t. `M *
block_size_tiling_factor = B_P_SIZE`. After tiling, KV cache has shape
`(num_block, num_head, block_size_tiling_factor, tiled_block_size, D)`.
Note:
We don't further tile D dimension as small DMA size also hurts performance.
"""
B_P_SIZE = 128
num_partitions, num_tiles_per_partition, num_blocks_per_tile = (
block_tables.shape)
assert num_tiles_per_partition == B_P_SIZE
assert is_power_of_2(
num_blocks_per_tile), f"{num_blocks_per_tile=} is not power of 2"
num_loads = cdiv(num_blocks_per_tile, B_P_SIZE)
block_tables_transposed = nl.ndarray(
(
num_loads,
par_dim(B_P_SIZE),
num_partitions * num_tiles_per_partition,
),
dtype=nl.int32,
)
# prepare iota ahead of time to avoid repeatedly using Gpsimd
if num_head > 1:
head_id = nisa.iota(head_id, dtype=nl.int32).reshape((1, 1))
head_id = nl.transpose(
head_id.broadcast_to((1, num_tiles_per_partition)))
if num_blocks_per_tile > 1:
head_id = head_id.broadcast_to(
(num_tiles_per_partition, num_blocks_per_tile))
if block_size_tiling_factor > 1:
broadcast_shape = (
num_tiles_per_partition,
num_blocks_per_tile,
block_size_tiling_factor,
)
offset = nisa.iota(nl.arange(block_size_tiling_factor)[None, None, :],
dtype=nl.int32).broadcast_to(broadcast_shape)
for partition_id in nl.affine_range(num_partitions):
block_tables_partition = block_tables[partition_id]
if num_head > 1:
# fuse num_block and num_head dimension
block_tables_partition = block_tables_partition * num_head + head_id
if block_size_tiling_factor > 1:
# need to apply block size tiling trick
assert num_blocks_per_tile * block_size_tiling_factor == B_P_SIZE
block_tables_partition = ((block_tables_partition *
block_size_tiling_factor).reshape(
(num_tiles_per_partition,
num_blocks_per_tile,
1)).broadcast_to(broadcast_shape))
new_block_tables = block_tables_partition + offset
new_block_tables = new_block_tables.reshape(
(num_tiles_per_partition, B_P_SIZE))
else:
new_block_tables = block_tables_partition
# transpose the block table so that it can be used by vector DGE
for i in nl.affine_range(num_loads):
i_p = nl.arange(B_P_SIZE)[:, None]
i_f = (partition_id * num_tiles_per_partition +
nl.arange(num_tiles_per_partition)[None, :])
block_tables_transposed[i, i_p, i_f] = nl.transpose(
new_block_tables[:, nl.ds(i * B_P_SIZE, B_P_SIZE)])
return block_tables_transposed
@nki.jit
def load_kv_tile_from_cache(
cur_k_tile,
cur_v_tile,
kv_cache,
block_tables,
large_k_tile_idx,
num_blocks_per_large_tile,
tiled_block_size,
B_P_SIZE,
B_D_SIZE,
):
"""
Load KV cache and transform Key and Value into layout required by Matmul
Vectorized DMA Load layout:
Key and Value: (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE)
Layout used by attention matmuls:
Key: (par_dim(B_D_SIZE), seqlen_kv)
Value: (seqlen_kv // B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE)
equivalent to (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE)
"""
# load key cache
num_loads = cdiv(num_blocks_per_large_tile, B_P_SIZE)
for load_idx in nl.affine_range(num_loads):
i_p = nl.arange(B_P_SIZE)[:, None]
i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :]
loaded = nl.load(kv_cache[0, block_tables[load_idx, i_p,
large_k_tile_idx], i_f])
if cur_k_tile.dtype != loaded.dtype:
loaded = nl.copy(loaded, dtype=cur_k_tile.dtype)
# Transpose SBUF tensor using PE
for tb_i in nl.affine_range(tiled_block_size):
cur_k_tile[
:,
nl.ds(
load_idx * B_P_SIZE * tiled_block_size + tb_i * B_P_SIZE,
B_P_SIZE,
),
] = nl.transpose(loaded[:, nl.ds(tb_i * B_D_SIZE, B_D_SIZE)])
# load value cache
for load_idx in nl.affine_range(num_loads):
loaded = nl.load(kv_cache[1, block_tables[load_idx, i_p,
large_k_tile_idx], i_f])
if cur_v_tile.dtype != loaded.dtype:
loaded = nl.copy(loaded, dtype=cur_v_tile.dtype)
i_p = nl.arange(B_P_SIZE)[:, None]
i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :]
cur_v_tile[
:,
nl.ds(
load_idx * tiled_block_size * B_D_SIZE,
tiled_block_size * B_D_SIZE,
),
] = loaded
@nki.jit
def transpose_p_local(p_local_transposed,
p_local,
LARGE_TILE_SZ,
B_F_SIZE=512):
for i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE):
if nisa.get_nc_version() == nisa.nc_version.gen3:
p_local_t_tmp = nl.ndarray((par_dim(128), B_F_SIZE),
buffer=nl.sbuf,
dtype=p_local.dtype)
else:
p_local_t_tmp = nl.ndarray((par_dim(128), B_F_SIZE),
buffer=nl.psum,
dtype=np.float32)
for j in nl.affine_range(B_F_SIZE // 128):
j_128_slice = nl.ds(j * 128, 128)
i_j_128_slice = nl.ds(i * B_F_SIZE + j * 128, 128)
if nisa.get_nc_version() == nisa.nc_version.gen3:
p_local_t_tmp[:, j_128_slice] = nisa.dma_transpose(
p_local[:, i_j_128_slice])
else:
p_local_t_tmp[:, j_128_slice] = nisa.nc_transpose(
p_local[:, i_j_128_slice])
p_local_transposed[:, nl.ds(i * B_F_SIZE, B_F_SIZE)] = nl.copy(
p_local_t_tmp, dtype=p_local_transposed.dtype)
@nki.jit
def _flash_attention_core(
q_local_tile,
k,
v,
o_buffer,
l_buffer,
m_buffer,
kernel_dtype,
acc_type,
tile_mask,
use_causal_mask,
q_tile_idx=None,
initialize=False,
LARGE_TILE_SZ=2048,
B_P_SIZE=128,
B_F_SIZE=512,
B_D_SIZE=128,
qk_res_buffer=None,
):
"""
The flash attention core function to calculate self attention between a tile
of q and a block of K and V.
The q_local_tile has (B_P_SIZE, B_D_SIZE)
The K and V have shape (B_D_SIZE, LARGE_TILE_SZ), whose free dimension will
be split into size B_F_SIZE tiles
The results are stored in the following three buffers
o_buffer: (B_P_SIZE, d)
l_buffer: (B_P_SIZE, 1)
m_buffer: (B_P_SIZE, 1)
All IO buffers are in SBUF.
"""
num_k_tile_per_large_tile = LARGE_TILE_SZ // B_F_SIZE
qk_res_buf = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
buffer=nl.sbuf,
dtype=acc_type)
max_local = nl.ndarray((par_dim(B_P_SIZE), num_k_tile_per_large_tile),
dtype=acc_type)
for k_i in nl.affine_range(num_k_tile_per_large_tile):
k_i_b_f_slice = nl.ds(k_i * B_F_SIZE, B_F_SIZE)
if use_causal_mask:
# mask are used to only apply computation to the lower half of the
# matrix, which reduce the arithmetic intensity by up to 50%
multiplication_required_selection = (q_tile_idx * B_P_SIZE
>= k_i * B_F_SIZE)
else:
multiplication_required_selection = True
if multiplication_required_selection:
qk_psum = nl.ndarray((par_dim(B_P_SIZE), B_F_SIZE),
dtype=np.float32,
buffer=nl.psum) # (128, 512)
qk_psum[:, :] = nl.matmul(q_local_tile,
k[:, k_i_b_f_slice],
transpose_x=True) # (p(128), 512)
qk_res_buf[:, k_i_b_f_slice] = nl.where(
tile_mask[:, k_i_b_f_slice],
qk_psum[:, nl.ds(0, B_F_SIZE)],
-9984.0,
dtype=acc_type,
)
else:
qk_res_buf[:, k_i_b_f_slice] = -9984.0
# Calculate max of the current tile
max_local[:, k_i] = nisa.tensor_reduce(
np.max,
qk_res_buf[:, k_i_b_f_slice],
axis=(1, ),
dtype=acc_type,
negate=False,
)
if qk_res_buffer is not None:
qk_res_buffer[:, :] = nl.copy(qk_res_buf[:, :])
max_ = nisa.tensor_reduce(
np.max,
max_local[:, :],
axis=(1, ),
dtype=acc_type,
negate=False,
)
o_previous_scaled = nl.ndarray((par_dim(B_P_SIZE), B_D_SIZE),
dtype=o_buffer.dtype)
if initialize:
m_buffer[:, 0] = nl.copy(max_)
m_current = max_
else:
m_previous = nl.copy(m_buffer[:, 0])
m_buffer[:, 0] = nl.maximum(m_previous, max_) # (128,1)
m_current = m_buffer[:, 0]
# Compute scaling factor
alpha = nisa.activation(
np.exp,
m_previous,
bias=-1 * m_current,
scale=1.0,
)
o_previous_scaled[...] = nl.multiply(o_buffer[:, :], alpha)
p_local = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
dtype=kernel_dtype)
REDUCTION_TILE = min(2048, LARGE_TILE_SZ // 2)
p_partial_sum = nl.ndarray(
(par_dim(B_P_SIZE), LARGE_TILE_SZ // REDUCTION_TILE),
dtype=acc_type,
)
for k_r_i in nl.affine_range(LARGE_TILE_SZ // REDUCTION_TILE):
k_r_i_reduce_slice = nl.ds(k_r_i * REDUCTION_TILE, REDUCTION_TILE)
# compute exp(qk - max)
# Compute partial row - tile sum of exp(qk - max))
# FIXME : Use activation accumulate to accumulate over k_r_i loop ?
p_local[:, k_r_i_reduce_slice] = nisa.activation_reduce(
np.exp,
qk_res_buf[:, k_r_i_reduce_slice],
bias=-1 * m_current,
scale=1.0,
reduce_op=nl.add,
reduce_res=p_partial_sum[:, k_r_i],
dtype=kernel_dtype,
)
ps = nl.sum(p_partial_sum, axis=1, dtype=acc_type)
p_local_transposed = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
dtype=kernel_dtype)
transpose_p_local(
p_local_transposed=p_local_transposed,
p_local=p_local,
LARGE_TILE_SZ=LARGE_TILE_SZ,
B_F_SIZE=B_F_SIZE,
)
pv_psum = nl.zeros(
(par_dim(B_P_SIZE), B_D_SIZE),
dtype=np.float32,
buffer=nl.psum,
)
for k_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE):
pv_psum[:, :] += nl.matmul(
p_local_transposed[:, nl.ds(k_i * B_P_SIZE, B_P_SIZE)],
v[:, nl.ds(k_i * B_D_SIZE, B_D_SIZE)],
transpose_x=True,
) # (128, 128) (p(Br), d)
if initialize:
o_buffer[:, :] = nl.copy(pv_psum[:, :])
l_buffer[:, 0] = nl.add(nl.log(ps), max_)
else:
o_buffer[:, :] = nl.add(o_previous_scaled, pv_psum)
l_prev = l_buffer[:, 0]
l_exp = nl.add(
nl.exp(nl.subtract(l_prev, m_current)),
ps,
)
l_buffer[:, 0] = nl.add(m_current, nl.log(l_exp))
@nki.jit
def load_v_tile(v_hbm_tile, cur_v_tile, large_tile_idx, v_i, LARGE_TILE_SZ):
B_P_SIZE = 128
B_D_SIZE = v_hbm_tile.shape[-1]
loaded = nl.load(v_hbm_tile[
nl.ds(large_tile_idx * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE),
:,
])
if cur_v_tile.dtype != loaded.dtype:
loaded = nl.copy(loaded, dtype=cur_v_tile.dtype)
cur_v_tile[:, nl.ds(v_i * B_D_SIZE, B_D_SIZE)] = loaded
@nki.jit
def flash_paged_attention(
query,
key,
value,
kv_cache,
block_tables,
mask,
softmax_scale=None,
mixed_precision=True,
LARGE_TILE_SZ=2048,
return_debug_tensors=False,
):
"""
Flash PagedAttention Forward Kernel.
IO tensor layouts:
- query: shape (1, n_heads, d, seq_q)
- key: shape (1, n_kv_heads, d, seq_k)
- value: shape (1, n_kv_heads, seq_v, d)
- kv_cache: (2, num_blocks, n_kv_heads, block_size, d)
- block_tables: (num_active_blocks, )
- mask: (seq_q, num_active_blocks * block_size + seq_q)
- o: shape (1, n_heads, seq_q, d)
- This kernel requires seq_k == seq_v
- We use continuous batching by default, so the batch dimension is
always 1, and different requests are concatenated along sequence
dimension.
- We use paged cache blocks (kv_cache) to store KV cache.
IO tensor dtypes:
- This kernel assumes all IO tensors have the same dtype except for
block_tables (int32) and mask (int32)
- If mixed_precision is True, then all Tensor Engine operation will be
performed in bfloat16 and accumulation will be performed in float32.
Otherwise the intermediates will be in the same type as the inputs.
Compile-time Constants:
- softmax_scale: scaling for softmax, is None, default is `1.0/(d**0.5)`
- mixed_precision: flag to set non-matmul ops in fp32 precision, default
is set to `true`, if false, we use same precision as input types
- LARGE_TILE_SZ: `default=2048`, size of the kv tile size for attention
computation reduction
GQA support Notes:
the spmd kernel for launching kernel should be on kv_heads instead of
nheads
Example usage:
MHA: q: [b, h, d, s], k: [b, h, d, s], v: [b, h, s, d]
usage: `flash_fwd[b, h](q, k, v, ...)`
GQA: q: [b, h, d, s], k: [b, kv_h, d, s], v: [b, kv_h, s, d]
usage: `flash_fwd[b, kv_h](q, k, v, ...)`
"""
B_F_SIZE = 512
B_P_SIZE = 128
b, h, d, seqlen_q = query.shape
B_D_SIZE = d
n_tile_q = seqlen_q // B_P_SIZE # since q will be loaded on tensor engine
_, num_blocks, k_h, block_size, _ = kv_cache.shape
q_h_per_k_h = h // k_h
assert b == 1, f"invalid batch size {b=}"
assert d <= 128, f" we do not support head_dim > 128, got head dim {d=}"
cache_shape = (2, num_blocks, k_h, block_size, d)
assert (tuple(kv_cache.shape) == cache_shape
), f"{kv_cache.shape=} mismatch, expect {cache_shape}"
assert key is None or tuple(key.shape) == (
1,
k_h,
d,
seqlen_q,
), f"key shape {key.shape} mismatch!"
assert value is None or tuple(value.shape) == (
1,
k_h,
seqlen_q,
d,
), f"value shape {value.shape} mismatch!"
assert (
nl.program_ndim() == 2
), f"Expect spmd grid with 2 dimensions, got {nl.program_ndim()} instead!"
batch_id = nl.program_id(axis=0)
head_id = nl.program_id(axis=1)
(num_active_blocks, ) = block_tables.shape
context_kv_len = num_active_blocks * block_size
assert (
LARGE_TILE_SZ % B_F_SIZE == 0
), f"Need {LARGE_TILE_SZ=} to be divisible by {B_F_SIZE=} in transpose_p"
assert (context_kv_len % LARGE_TILE_SZ == 0
), f"Need {context_kv_len=} to be divisible by {LARGE_TILE_SZ=}"
num_blocks_per_large_tile = LARGE_TILE_SZ // block_size
assert is_power_of_2(
num_blocks_per_large_tile
), f"{num_blocks_per_large_tile=} is expected of be power of 2"
if seqlen_q > B_F_SIZE:
MAX_REDUCTION_TILE = 2048
if seqlen_q // 2 > MAX_REDUCTION_TILE:
assert (
seqlen_q % MAX_REDUCTION_TILE == 0
), f"{seqlen_q=} should be divisible by {MAX_REDUCTION_TILE=}"
else:
assert (seqlen_q % B_F_SIZE == 0
), f"{seqlen_q=} should be divisible by {B_F_SIZE=})"
kernel_dtype = nl.bfloat16 if mixed_precision else query.dtype
acc_type = np.dtype(np.float32) if mixed_precision else kernel_dtype
softmax_scale = softmax_scale or (1.0 / (d**0.5))
num_large_k_tile = context_kv_len // LARGE_TILE_SZ
o = nl.ndarray((b, h, seqlen_q, d),
dtype=query.dtype,
buffer=nl.shared_hbm)
hbm_l_buffer, hbm_m_buffer, hbm_qk_res, qk_res_buffer = (
None,
None,
None,
None,
)
if return_debug_tensors:
hbm_l_buffer = nl.ndarray((b, h, seqlen_q),
dtype=acc_type,
buffer=nl.shared_hbm)
hbm_m_buffer = nl.ndarray((b, h, seqlen_q),
dtype=acc_type,
buffer=nl.shared_hbm)
hbm_qk_res = nl.ndarray((b, h, B_P_SIZE, seqlen_q),
dtype=acc_type,
buffer=nl.shared_hbm)
qk_res_buffer = nl.zeros(
(n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), seqlen_q),
dtype=acc_type,
buffer=nl.sbuf,
lazy_initialization=True,
)
block_tables_sbuf = load_block_tables(
block_tables_hbm=block_tables,
num_tiles=num_large_k_tile,
num_blocks_per_tile=num_blocks_per_large_tile,
)
# On Neuron, we need B_P_SIZE = 128 blocks to make DMA efficient
if num_blocks_per_large_tile < B_P_SIZE:
# we checked num_blocks_per_tile is a power of 2
assert B_P_SIZE % num_blocks_per_large_tile == 0
block_size_tiling_factor = B_P_SIZE // num_blocks_per_large_tile
# We assume block_size >= block_size_tiling_factor
assert block_size % block_size_tiling_factor == 0
else:
block_size_tiling_factor = 1
tiled_block_size = block_size // block_size_tiling_factor
# Indirect DMA load must be placed along Partition Dimension
block_tables_sbuf = transform_block_tables_for_indirect_load(
block_tables_sbuf,
block_size_tiling_factor=block_size_tiling_factor,
num_head=k_h,
head_id=head_id,
)
# Flatten KV cache to be 3D for loading into SBUF
new_cache_shape = (
2,
num_blocks * k_h * block_size_tiling_factor,
tiled_block_size * d,
)
kv_cache = kv_cache.reshape(new_cache_shape)
# Global Flash Attention accumulators
o_buffer = nl.zeros(
(n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), d),
dtype=acc_type,
buffer=nl.sbuf,
lazy_initialization=True,
)
l_buffer = nl.zeros(
(n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), 1),
dtype=acc_type,
buffer=nl.sbuf,
lazy_initialization=True,
)
m_buffer = nl.zeros(
(n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), 1),
dtype=acc_type,
buffer=nl.sbuf,
lazy_initialization=True,
)
for large_k_tile_idx in nl.sequential_range(0, num_large_k_tile):
num_loads = cdiv(num_blocks_per_large_tile, B_P_SIZE)
cur_k_tile = nl.ndarray(
(par_dim(B_D_SIZE), LARGE_TILE_SZ),
dtype=kernel_dtype,
)
cur_v_tile = nl.ndarray(
(par_dim(B_P_SIZE), num_loads * tiled_block_size * B_D_SIZE),
dtype=kernel_dtype,
)
load_kv_tile_from_cache(
cur_k_tile=cur_k_tile,
cur_v_tile=cur_v_tile,
kv_cache=kv_cache,
block_tables=block_tables_sbuf,
large_k_tile_idx=large_k_tile_idx,
num_blocks_per_large_tile=num_blocks_per_large_tile,
tiled_block_size=tiled_block_size,
B_P_SIZE=B_P_SIZE,
B_D_SIZE=B_D_SIZE,
)
for i in nl.affine_range(n_tile_q):
cur_mask = nl.load(mask[
nl.ds(i * B_P_SIZE, B_P_SIZE),
nl.ds(large_k_tile_idx * LARGE_TILE_SZ, LARGE_TILE_SZ),
])
for i_q_h in nl.affine_range(q_h_per_k_h):
q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype)
q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h]
q_sbuf_tile = nl.load(q_hbm_tile[:,
nl.ds(i *
B_P_SIZE, B_P_SIZE)])
if q_sbuf_tile.dtype != kernel_dtype:
q_sbuf_tile = nl.copy(q_sbuf_tile, dtype=kernel_dtype)
q_tile[:, :] = q_sbuf_tile * softmax_scale
_flash_attention_core(
q_local_tile=q_tile,
k=cur_k_tile,
v=cur_v_tile,
o_buffer=o_buffer[i, i_q_h],
l_buffer=l_buffer[i, i_q_h],
m_buffer=m_buffer[i, i_q_h],
kernel_dtype=kernel_dtype,
acc_type=acc_type,
tile_mask=cur_mask,
use_causal_mask=False,
q_tile_idx=i,
initialize=large_k_tile_idx == 0,
LARGE_TILE_SZ=LARGE_TILE_SZ,
B_P_SIZE=B_P_SIZE,
B_F_SIZE=B_F_SIZE,
B_D_SIZE=B_D_SIZE,
)
# compute attention between input query, key and value
if key is not None and value is not None:
B_F_SIZE = min(seqlen_q, B_F_SIZE)
LARGE_TILE_SZ = seqlen_q
cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ),
dtype=kernel_dtype)
cur_v_tile = nl.ndarray(
(par_dim(B_P_SIZE), LARGE_TILE_SZ // B_P_SIZE * B_D_SIZE),
dtype=kernel_dtype,
)
loaded = nl.load(key[batch_id, head_id, :, :])
if loaded.dtype != kernel_dtype:
loaded = nl.copy(loaded, dtype=kernel_dtype)
cur_k_tile[:, :] = loaded
v_hbm_tile = value[batch_id, head_id]
for v_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE):
load_v_tile(
v_hbm_tile=v_hbm_tile,
cur_v_tile=cur_v_tile,
large_tile_idx=0,
v_i=v_i,
LARGE_TILE_SZ=LARGE_TILE_SZ,
)
for i in nl.affine_range(n_tile_q):
cur_mask = nl.load(mask[
nl.ds(i * B_P_SIZE, B_P_SIZE),
nl.ds(context_kv_len, LARGE_TILE_SZ),
])
for i_q_h in nl.affine_range(q_h_per_k_h):
q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype)
q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h]
q_sbuf_tile = nl.load(q_hbm_tile[:,
nl.ds(i *
B_P_SIZE, B_P_SIZE)])
if q_sbuf_tile.dtype != kernel_dtype:
q_sbuf_tile = nl.copy(q_sbuf_tile, dtype=kernel_dtype)
q_tile[:, :] = q_sbuf_tile * softmax_scale
_flash_attention_core(
q_local_tile=q_tile,
k=cur_k_tile,
v=cur_v_tile,
o_buffer=o_buffer[i, i_q_h],
l_buffer=l_buffer[i, i_q_h],
m_buffer=m_buffer[i, i_q_h],
kernel_dtype=kernel_dtype,
acc_type=acc_type,
tile_mask=cur_mask,
use_causal_mask=True,
q_tile_idx=i,
initialize=False,
LARGE_TILE_SZ=LARGE_TILE_SZ,
B_P_SIZE=B_P_SIZE,
B_F_SIZE=B_F_SIZE,
B_D_SIZE=B_D_SIZE,
qk_res_buffer=(qk_res_buffer[i, i_q_h]
if qk_res_buffer is not None else None),
)
# -- -- -- -- write output to buffer on HBM -- -- -- -- -- -- #
for i_q_h in nl.affine_range(q_h_per_k_h):
for i in nl.affine_range(n_tile_q):
out = nl.multiply(
o_buffer[i, i_q_h],
nl.exp(m_buffer[i, i_q_h] - l_buffer[i, i_q_h]),
dtype=kernel_dtype,
)
nl.store(
o[
batch_id,
head_id * q_h_per_k_h + i_q_h,
nl.ds(i * B_P_SIZE, B_P_SIZE),
:,
],
out,
)
# maximum and summation statistics
if return_debug_tensors:
nl.store(
hbm_m_buffer[
batch_id,
head_id * q_h_per_k_h + i_q_h,
nl.ds(i * B_P_SIZE, B_P_SIZE),
],
m_buffer[i, i_q_h, :, :],
)
nl.store(
hbm_l_buffer[
batch_id,
head_id * q_h_per_k_h + i_q_h,
nl.ds(i * B_P_SIZE, B_P_SIZE),
],
l_buffer[i, i_q_h],
)
nl.store(
hbm_qk_res[batch_id, head_id * q_h_per_k_h + i_q_h, :, :],
qk_res_buffer[batch_id, i_q_h, :, :],
)
if return_debug_tensors:
return o, hbm_m_buffer, hbm_l_buffer, hbm_qk_res
return o
def reorder_context_mask(mask, LARGE_TILE_SZ, block_size):
"""
Reorder the mask to make it compatible with the flash attention kernel.
We vectorize KV cache read to improve DMA utilization. However, the layout
that maximizes DMA bandwidth changes the order tokens are consumed.
The token layout (inner 2 dimensions) after vectorized load is (B_P_SIZE,
tiled_block_size) in a tile of `B_P_SIZE * tiled_block_size` tokens. And
each step the engine consumes a column (rather than a row) of B_P_SIZE
tokens. Therefore, the tokens are visited in a strided way.
To make sure mask matches the order tokens are consumed, we need to properly
transpose mask.
"""
total_query_len, total_seq_len = mask.shape
context_kv_len = total_seq_len - total_query_len
B_P_SIZE = 128
assert (LARGE_TILE_SZ
>= B_P_SIZE), f"{LARGE_TILE_SZ=} must be larger than {B_P_SIZE=}"
num_tiled_blocks = max(B_P_SIZE, LARGE_TILE_SZ // block_size)
tiled_block_size = LARGE_TILE_SZ // num_tiled_blocks
if tiled_block_size > 1:
# Mask reordering is needed when tiled_block_size > 1
device = mask.device
mask = mask.cpu()
context_mask = mask[:, :context_kv_len]
context_mask = context_mask.view(
total_query_len,
context_kv_len // LARGE_TILE_SZ,
num_tiled_blocks // B_P_SIZE,
B_P_SIZE,
tiled_block_size,
)
context_mask = context_mask.transpose(3, 4).reshape(
total_query_len, context_kv_len)
new_mask = mask[:, context_kv_len:]
return torch.concat([context_mask, new_mask], dim=1).to(device)
else:
return mask
def flash_attn_varlen_nkifunc(
query,
key,
value,
kv_cache,
block_table,
attn_mask,
n_kv_head=None,
head_size=None,
LARGE_TILE_SZ=2048,
mixed_precision=True,
):
"""
Compute flash paged attention for variable length sequences.
This function is a wrapper around the flash attention NKI kernel. It takes
in the following arguments:
- query: (1, n_heads, d, seq_q)
- key: (1, n_kv_heads, d, seq_k)
- value: (1, n_kv_heads, seq_v, d)
- kv_cache: (2, n_blocks, n_kv_heads, block_size, d)
- block_tables: (n_active_blocks, )
- attn_mask: (seq_q, n_active_blocks * block_size + seq_q)
Notes:
- attn_mask must be reordered outside using `reorder_context_mask`
- Key/value cache layout must be (n_blocks, n_kv_heads, block_size, d)
for better DMA throughput
"""
if n_kv_head is None:
n_kv_head = kv_cache.shape[2]
assert kv_cache.shape[0] == 2
assert kv_cache.shape[2] == n_kv_head
if head_size is None:
head_size = kv_cache.shape[-1]
kwargs = dict(
query=query,
key=key,
value=value,
kv_cache=kv_cache,
block_tables=block_table,
mask=attn_mask,
softmax_scale=1.0 / (head_size**0.5),
mixed_precision=mixed_precision,
LARGE_TILE_SZ=LARGE_TILE_SZ,
)
o = flash_paged_attention[1, n_kv_head](**kwargs)
return o
def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
) -> None:
"""
Writes key-value pairs to the KV cache at specified positions.
Args:
key (torch.Tensor): Key tensor with shape
(num_tokens, n_kv_head, d_head)
value (torch.Tensor): Value tensor with shape
(num_tokens, n_kv_head, d_head)
kv_cache (torch.Tensor): Key/value cache tensor with shape
(2, num_blocks, n_kv_head, block_size, d_head)
slot_mapping (torch.Tensor): Mapping tensor indicating cache positions
with shape (num_tokens)
Returns:
None: Updates the kv_cache tensor in-place
"""
block_size = kv_cache.size(3)
n_kv_head = key.size(1)
# Calculate indices with explicit floor division
block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
block_offsets = slot_mapping % block_size
# Create the head indices tensor
head_indices = torch.arange(n_kv_head, device=key.device)
# Update caches using index_put_
kv_cache.index_put_(
(torch.tensor([0], device=key.device), block_indices[:, None],
head_indices[None, :], block_offsets[:, None]), key)
kv_cache.index_put_(
(torch.tensor([1], device=key.device), block_indices[:, None],
head_indices[None, :], block_offsets[:, None]), value)
...@@ -54,7 +54,6 @@ SystemEnv = namedtuple( ...@@ -54,7 +54,6 @@ SystemEnv = namedtuple(
'is_xnnpack_available', 'is_xnnpack_available',
'cpu_info', 'cpu_info',
'rocm_version', # vllm specific field 'rocm_version', # vllm specific field
'neuron_sdk_version', # vllm specific field
'vllm_version', # vllm specific field 'vllm_version', # vllm specific field
'vllm_build_flags', # vllm specific field 'vllm_build_flags', # vllm specific field
'gpu_topo', # vllm specific field 'gpu_topo', # vllm specific field
...@@ -275,15 +274,6 @@ def get_rocm_version(run_lambda): ...@@ -275,15 +274,6 @@ def get_rocm_version(run_lambda):
r'HIP version: (\S+)') r'HIP version: (\S+)')
def get_neuron_sdk_version(run_lambda):
# Adapted from your install script
try:
result = run_lambda(["neuron-ls"])
return result if result[0] == 0 else 'N/A'
except Exception:
return 'N/A'
def get_vllm_version(): def get_vllm_version():
from vllm import __version__, __version_tuple__ from vllm import __version__, __version_tuple__
...@@ -306,10 +296,9 @@ def get_vllm_version(): ...@@ -306,10 +296,9 @@ def get_vllm_version():
def summarize_vllm_build_flags(): def summarize_vllm_build_flags():
# This could be a static method if the flags are constant, or dynamic if you need to check environment variables, etc. # This could be a static method if the flags are constant, or dynamic if you need to check environment variables, etc.
return 'CUDA Archs: {}; ROCm: {}; Neuron: {}'.format( return 'CUDA Archs: {}; ROCm: {}'.format(
os.environ.get('TORCH_CUDA_ARCH_LIST', 'Not Set'), os.environ.get('TORCH_CUDA_ARCH_LIST', 'Not Set'),
'Enabled' if os.environ.get('ROCM_HOME') else 'Disabled', 'Enabled' if os.environ.get('ROCM_HOME') else 'Disabled',
'Enabled' if os.environ.get('NEURON_CORES') else 'Disabled',
) )
...@@ -601,7 +590,6 @@ def get_env_info(): ...@@ -601,7 +590,6 @@ def get_env_info():
conda_packages = get_conda_packages(run_lambda) conda_packages = get_conda_packages(run_lambda)
rocm_version = get_rocm_version(run_lambda) rocm_version = get_rocm_version(run_lambda)
neuron_sdk_version = get_neuron_sdk_version(run_lambda)
vllm_version = get_vllm_version() vllm_version = get_vllm_version()
vllm_build_flags = summarize_vllm_build_flags() vllm_build_flags = summarize_vllm_build_flags()
gpu_topo = get_gpu_topo(run_lambda) gpu_topo = get_gpu_topo(run_lambda)
...@@ -635,7 +623,6 @@ def get_env_info(): ...@@ -635,7 +623,6 @@ def get_env_info():
is_xnnpack_available=is_xnnpack_available(), is_xnnpack_available=is_xnnpack_available(),
cpu_info=get_cpu_info(run_lambda), cpu_info=get_cpu_info(run_lambda),
rocm_version=rocm_version, rocm_version=rocm_version,
neuron_sdk_version=neuron_sdk_version,
vllm_version=vllm_version, vllm_version=vllm_version,
vllm_build_flags=vllm_build_flags, vllm_build_flags=vllm_build_flags,
gpu_topo=gpu_topo, gpu_topo=gpu_topo,
...@@ -702,7 +689,6 @@ env_info_fmt += """ ...@@ -702,7 +689,6 @@ env_info_fmt += """
vLLM Info vLLM Info
============================== ==============================
ROCM Version : {rocm_version} ROCM Version : {rocm_version}
Neuron SDK Version : {neuron_sdk_version}
vLLM Version : {vllm_version} vLLM Version : {vllm_version}
vLLM Build Flags: vLLM Build Flags:
{vllm_build_flags} {vllm_build_flags}
......
...@@ -461,11 +461,6 @@ class ModelConfig: ...@@ -461,11 +461,6 @@ class ModelConfig:
DP (which is controlled by `--data-parallel-size`). DP (which is controlled by `--data-parallel-size`).
This is only supported on a per-model basis and falls back to This is only supported on a per-model basis and falls back to
`"weights"` if the encoder does not support DP.""" `"weights"` if the encoder does not support DP."""
override_neuron_config: dict[str, Any] = field(default_factory=dict)
"""Initialize non-default neuron config or override default neuron config
that are specific to Neuron devices, this argument will be used to
configure the neuron config that can not be gathered from the vllm
arguments. e.g. `{"cast_logits_dtype": "bfloat16"}`."""
pooler_config: Optional["PoolerConfig"] = field(init=False) pooler_config: Optional["PoolerConfig"] = field(init=False)
"""Pooler config which controls the behaviour of output pooling in pooling """Pooler config which controls the behaviour of output pooling in pooling
models.""" models."""
...@@ -785,10 +780,6 @@ class ModelConfig: ...@@ -785,10 +780,6 @@ class ModelConfig:
if not self.skip_tokenizer_init: if not self.skip_tokenizer_init:
self._verify_tokenizer_mode() self._verify_tokenizer_mode()
if (not current_platform.is_neuron() and self.override_neuron_config):
raise ValueError(
"`override_neuron_config` is only supported on Neuron.")
# Avoid running try_verify_and_update_config multiple times # Avoid running try_verify_and_update_config multiple times
self.config_updated = False self.config_updated = False
...@@ -1696,13 +1687,7 @@ class ModelConfig: ...@@ -1696,13 +1687,7 @@ class ModelConfig:
""" """
For Mllama, VLLM overrides HF's is_encoder_decoder flag and sets it to For Mllama, VLLM overrides HF's is_encoder_decoder flag and sets it to
True to enable cross-attention True to enable cross-attention
Neuron needs all multimodal data to be in the decoder and does not
need to explicitly enable cross-attention
""" """
if (current_platform.is_neuron()
and self.hf_config.model_type == "mllama"):
return False
return is_encoder_decoder(self.hf_config) return is_encoder_decoder(self.hf_config)
@property @property
...@@ -1871,7 +1856,7 @@ class LoadConfig: ...@@ -1871,7 +1856,7 @@ class LoadConfig:
self.ignore_patterns = ["original/**/*"] self.ignore_patterns = ["original/**/*"]
Device = Literal["auto", "cuda", "neuron", "cpu", "tpu", "xpu"] Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"]
@config @config
...@@ -1927,9 +1912,7 @@ class DeviceConfig: ...@@ -1927,9 +1912,7 @@ class DeviceConfig:
self.device_type = self.device.type self.device_type = self.device.type
# Some device types require processing inputs on CPU # Some device types require processing inputs on CPU
if self.device_type in ["neuron"]: if self.device_type in ["tpu"]:
self.device = torch.device("cpu")
elif self.device_type in ["tpu"]:
self.device = None self.device = None
else: else:
# Set device with device type # Set device with device type
...@@ -3941,7 +3924,6 @@ class VllmConfig: ...@@ -3941,7 +3924,6 @@ class VllmConfig:
f"skip_tokenizer_init={self.model_config.skip_tokenizer_init}, " f"skip_tokenizer_init={self.model_config.skip_tokenizer_init}, "
f"tokenizer_mode={self.model_config.tokenizer_mode}, " f"tokenizer_mode={self.model_config.tokenizer_mode}, "
f"revision={self.model_config.revision}, " f"revision={self.model_config.revision}, "
f"override_neuron_config={self.model_config.override_neuron_config}, " # noqa
f"tokenizer_revision={self.model_config.tokenizer_revision}, " f"tokenizer_revision={self.model_config.tokenizer_revision}, "
f"trust_remote_code={self.model_config.trust_remote_code}, " f"trust_remote_code={self.model_config.trust_remote_code}, "
f"dtype={self.model_config.dtype}, " f"dtype={self.model_config.dtype}, "
......
...@@ -33,9 +33,8 @@ class CacheConfig: ...@@ -33,9 +33,8 @@ class CacheConfig:
"""Configuration for the KV cache.""" """Configuration for the KV cache."""
block_size: SkipValidation[BlockSize] = None # type: ignore block_size: SkipValidation[BlockSize] = None # type: ignore
"""Size of a contiguous cache block in number of tokens. This is ignored on """Size of a contiguous cache block in number of tokens. On CUDA devices,
neuron devices and set to `--max-model-len`. On CUDA devices, only block only block sizes up to 32 are supported.
sizes up to 32 are supported. On HPU devices, block size defaults to 128.
This config has no static default. If left unspecified by the user, it will This config has no static default. If left unspecified by the user, it will
be set in `Platform.check_and_update_config()` based on the current be set in `Platform.check_and_update_config()` based on the current
......
...@@ -377,10 +377,7 @@ class ParallelConfig: ...@@ -377,10 +377,7 @@ class ParallelConfig:
from vllm.executor import ray_utils from vllm.executor import ray_utils
backend: DistributedExecutorBackend = "mp" backend: DistributedExecutorBackend = "mp"
ray_found = ray_utils.ray_is_available() ray_found = ray_utils.ray_is_available()
if current_platform.is_neuron(): if current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD:
# neuron uses single process to control multiple devices
backend = "uni"
elif current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD:
backend = "uni" backend = "uni"
elif (current_platform.is_cuda() elif (current_platform.is_cuda()
and cuda_device_count_stateless() < self.world_size): and cuda_device_count_stateless() < self.world_size):
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.distributed.device_communicators.base_device_communicator import (
DeviceCommunicatorBase)
from vllm.platforms import current_platform
if current_platform.is_neuron():
import torch_xla.core.xla_model as xm
class NeuronCommunicator(DeviceCommunicatorBase):
def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
return xm.all_reduce(xm.REDUCE_SUM, x)
def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
assert dim == -1, "Neuron only supports dim=-1 for all-gather."
return xm.all_gather(x, dim=dim)
...@@ -419,8 +419,6 @@ class EngineArgs: ...@@ -419,8 +419,6 @@ class EngineArgs:
scheduling_policy: SchedulerPolicy = SchedulerConfig.policy scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls
override_neuron_config: dict[str, Any] = \
get_field(ModelConfig, "override_neuron_config")
override_pooler_config: Optional[Union[dict, PoolerConfig]] = \ override_pooler_config: Optional[Union[dict, PoolerConfig]] = \
ModelConfig.override_pooler_config ModelConfig.override_pooler_config
compilation_config: CompilationConfig = \ compilation_config: CompilationConfig = \
...@@ -561,8 +559,6 @@ class EngineArgs: ...@@ -561,8 +559,6 @@ class EngineArgs:
help=model_kwargs["hf_token"]["help"]) help=model_kwargs["hf_token"]["help"])
model_group.add_argument("--hf-overrides", model_group.add_argument("--hf-overrides",
**model_kwargs["hf_overrides"]) **model_kwargs["hf_overrides"])
model_group.add_argument("--override-neuron-config",
**model_kwargs["override_neuron_config"])
model_group.add_argument("--override-pooler-config", model_group.add_argument("--override-pooler-config",
**model_kwargs["override_pooler_config"]) **model_kwargs["override_pooler_config"])
model_group.add_argument("--logits-processor-pattern", model_group.add_argument("--logits-processor-pattern",
...@@ -992,7 +988,6 @@ class EngineArgs: ...@@ -992,7 +988,6 @@ class EngineArgs:
mm_processor_kwargs=self.mm_processor_kwargs, mm_processor_kwargs=self.mm_processor_kwargs,
mm_processor_cache_gb=self.mm_processor_cache_gb, mm_processor_cache_gb=self.mm_processor_cache_gb,
mm_encoder_tp_mode=self.mm_encoder_tp_mode, mm_encoder_tp_mode=self.mm_encoder_tp_mode,
override_neuron_config=self.override_neuron_config,
override_pooler_config=self.override_pooler_config, override_pooler_config=self.override_pooler_config,
logits_processor_pattern=self.logits_processor_pattern, logits_processor_pattern=self.logits_processor_pattern,
generation_config=self.generation_config, generation_config=self.generation_config,
......
...@@ -236,7 +236,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -236,7 +236,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# ================== Installation Time Env Vars ================== # ================== Installation Time Env Vars ==================
# Target device of vLLM, supporting [cuda (by default), # Target device of vLLM, supporting [cuda (by default),
# rocm, neuron, cpu] # rocm, cpu]
"VLLM_TARGET_DEVICE": "VLLM_TARGET_DEVICE":
lambda: os.getenv("VLLM_TARGET_DEVICE", "cuda").lower(), lambda: os.getenv("VLLM_TARGET_DEVICE", "cuda").lower(),
......
...@@ -73,11 +73,6 @@ class CustomOp(nn.Module): ...@@ -73,11 +73,6 @@ class CustomOp(nn.Module):
# NOTE(woosuk): This is a placeholder for future extensions. # NOTE(woosuk): This is a placeholder for future extensions.
return self.forward_native(*args, **kwargs) return self.forward_native(*args, **kwargs)
def forward_neuron(self, *args, **kwargs):
# By default, we assume that Neuron ops are compatible with the
# PyTorch-native implementation.
return self.forward_native(*args, **kwargs)
def forward_oot(self, *args, **kwargs): def forward_oot(self, *args, **kwargs):
# By default, we assume that OOT ops are compatible with the # By default, we assume that OOT ops are compatible with the
# PyTorch-native implementation. # PyTorch-native implementation.
...@@ -105,8 +100,6 @@ class CustomOp(nn.Module): ...@@ -105,8 +100,6 @@ class CustomOp(nn.Module):
return self.forward_tpu return self.forward_tpu
elif current_platform.is_xpu(): elif current_platform.is_xpu():
return self.forward_xpu return self.forward_xpu
elif current_platform.is_neuron():
return self.forward_neuron
elif current_platform.is_out_of_tree(): elif current_platform.is_out_of_tree():
return self.forward_oot return self.forward_oot
else: else:
......
...@@ -95,13 +95,6 @@ class SiluAndMul(CustomOp): ...@@ -95,13 +95,6 @@ class SiluAndMul(CustomOp):
self.op(out, x) self.op(out, x)
return out return out
def forward_neuron(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
x_reshaped = x.view(-1, x.shape[-1])
s = x_reshaped[:, :d] * F.sigmoid(x_reshaped[:, :d])
result = s * x_reshaped[:, d:]
return result.view(*x.shape[:-1], d)
@CustomOp.register("mul_and_silu") @CustomOp.register("mul_and_silu")
class MulAndSilu(CustomOp): class MulAndSilu(CustomOp):
......
...@@ -26,7 +26,6 @@ QuantizationMethods = Literal[ ...@@ -26,7 +26,6 @@ QuantizationMethods = Literal[
"bitsandbytes", "bitsandbytes",
"hqq", "hqq",
"experts_int8", "experts_int8",
"neuron_quant",
"ipex", "ipex",
"quark", "quark",
"moe_wna16", "moe_wna16",
...@@ -108,7 +107,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: ...@@ -108,7 +107,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config
from .moe_wna16 import MoeWNA16Config from .moe_wna16 import MoeWNA16Config
from .mxfp4 import Mxfp4Config from .mxfp4 import Mxfp4Config
from .neuron_quant import NeuronQuantConfig
from .petit import PetitNvFp4Config from .petit import PetitNvFp4Config
from .ptpc_fp8 import PTPCFp8Config from .ptpc_fp8 import PTPCFp8Config
from .rtn import RTNConfig from .rtn import RTNConfig
...@@ -135,7 +133,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: ...@@ -135,7 +133,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"ptpc_fp8": PTPCFp8Config, "ptpc_fp8": PTPCFp8Config,
"hqq": HQQMarlinConfig, "hqq": HQQMarlinConfig,
"experts_int8": ExpertsInt8Config, "experts_int8": ExpertsInt8Config,
"neuron_quant": NeuronQuantConfig,
"ipex": IPEXConfig, "ipex": IPEXConfig,
"quark": QuarkConfig, "quark": QuarkConfig,
"moe_wna16": MoeWNA16Config, "moe_wna16": MoeWNA16Config,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from importlib.util import find_spec
from typing import Any, Optional
from torch.nn import Module
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
SUPPORTED_QUANT_DTYPE_LIST = ['s8', 'f8e4m3fn']
class AlwaysSupportedDtypes(list):
def __contains__(self, item):
return True
class NeuronQuantConfig(QuantizationConfig):
"""Int8 Quantization Config class for Neuron Backend."""
def __init__(
self,
dequant_dtype: str = "f16",
quantize_method: str = "vector_dynamic",
) -> None:
super().__init__()
self.quant_dtype = os.getenv("NEURON_QUANT_DTYPE", "s8")
if self.quant_dtype not in SUPPORTED_QUANT_DTYPE_LIST:
raise ValueError(
f"Neuron quantization datatype {self.quant_dtype} is not valid,"
f" the quantization datatype should match one of the below "
f"types {SUPPORTED_QUANT_DTYPE_LIST}")
self.dequant_dtype = dequant_dtype
self.quantize_method = quantize_method
def get_name(self) -> QuantizationMethods:
return "neuron_quant"
def get_supported_act_dtypes(self) -> list[str]:
# Neuron implements custom handling logic for quantization support
return AlwaysSupportedDtypes()
@classmethod
def get_min_capability(cls) -> int:
raise NotImplementedError(
"This function should not be called with Neuron Backend")
@staticmethod
def get_config_filenames() -> list[str]:
return []
@classmethod
def from_config(cls, config: dict[str, Any]) -> "NeuronQuantConfig":
quantize_method = cls.get_from_keys(config, ["quantize_method"])
dequant_dtype = cls.get_from_keys(config, ["dequant_dtype"])
return cls(dequant_dtype=dequant_dtype,
quantize_method=quantize_method)
def get_quant_method(self, layer: Module, prefix: str) -> Optional[Any]:
if find_spec("transformers_neuronx") is not None:
return self.get_quantization_config()
else:
raise NotImplementedError(
"Neuron Quantization is only supported through"
" transformers_neuronx.")
def get_quantization_config(self):
from transformers_neuronx.config import QuantizationConfig
return QuantizationConfig(quant_dtype=self.quant_dtype,
dequant_dtype=self.dequant_dtype,
quantize_method=self.quantize_method)
...@@ -7,7 +7,7 @@ import torch ...@@ -7,7 +7,7 @@ import torch
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from .common import apply_rotary_emb_dispatch, apply_rotary_emb_torch from .common import apply_rotary_emb_torch
@CustomOp.register("rotary_embedding") @CustomOp.register("rotary_embedding")
...@@ -149,87 +149,6 @@ class RotaryEmbedding(CustomOp): ...@@ -149,87 +149,6 @@ class RotaryEmbedding(CustomOp):
self.cos_sin_cache, self.is_neox_style) self.cos_sin_cache, self.is_neox_style)
return query, key return query, key
def forward_neuron(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
def _apply_rotary_emb_neuron(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
is_neox_style: bool,
) -> torch.Tensor:
cos = cos.unsqueeze(-2).to(x.dtype)
sin = sin.unsqueeze(-2).to(x.dtype)
if is_neox_style:
x1, x2 = torch.chunk(x, 2, dim=-1)
else:
# x1 = x[..., ::2]
# x2 = x[..., 1::2]
d = x.shape[-1] // 2
x_reshaped = x.view(-1, x.shape[-1])
x1 = x_reshaped[:, ::2].view(*x.shape[:-1], d)
x2 = x_reshaped[:, 1::2].view(*x.shape[:-1], d)
o1 = x1 * cos - x2 * sin
o2 = x2 * cos + x1 * sin
if is_neox_style:
return torch.cat((o1, o2), dim=-1)
else:
return torch.stack((o1, o2), dim=-1).flatten(-2)
if offsets is not None:
positions = positions + offsets
self.cos_sin_cache = self.cos_sin_cache.to(query.device,
dtype=query.dtype)
positions = positions.flatten()
num_tokens = positions.shape[0]
cos_sin = self.cos_sin_cache.index_select(0, positions)
cos, sin = cos_sin.chunk(2, dim=-1)
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
if key is not None:
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
if self.rotary_dim == self.head_size:
query = apply_rotary_emb_dispatch(query, cos, sin,
self.is_neox_style)
query = query.reshape(query_shape)
if key is not None:
key = apply_rotary_emb_dispatch(key, cos, sin,
self.is_neox_style)
key = key.reshape(key_shape)
else:
head_size = query.shape[-1]
query_reshaped = query.view(-1, head_size)
query_pass = query_reshaped[:, self.rotary_dim:].view(
*query.shape[:-1], head_size - self.rotary_dim)
query_rot = query_reshaped[:, :self.rotary_dim].view(
*query.shape[:-1], self.rotary_dim)
query_rot = _apply_rotary_emb_neuron(query_rot, cos, sin,
self.is_neox_style)
query = torch.cat((query_rot, query_pass),
dim=-1).reshape(query_shape)
if key is not None:
key_reshaped = key.view(-1, head_size)
key_pass = key_reshaped[:, self.rotary_dim:].view(
*key.shape[:-1], head_size - self.rotary_dim)
key_rot = key_reshaped[:, :self.rotary_dim].view(
*key.shape[:-1], self.rotary_dim)
key_rot = _apply_rotary_emb_neuron(key_rot, cos, sin,
self.is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
def extra_repr(self) -> str: def extra_repr(self) -> str:
s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
s += f", max_position_embeddings={self.max_position_embeddings}" s += f", max_position_embeddings={self.max_position_embeddings}"
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utilities for selecting and loading Neuron models in transformers-neuronx
framework."""
import ast
import copy
import importlib
import os
from typing import Optional
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.logprobs import Logprob
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import get_quantization_config
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import CompletionSequenceGroupOutput, SequenceOutput
TORCH_DTYPE_TO_NEURON_AMP = {
"auto": "f32",
"half": "f16",
"float16": "f16",
"bfloat16": "bf16",
"float": "f32",
"float32": "f32",
torch.float16: "f16",
torch.bfloat16: "bf16",
torch.float32: "f32",
}
# Models supported by Neuron.
_NEURON_SUPPORTED_MODELS: dict[str, tuple[str, str, str]] = {
"LlamaForCausalLM": ("transformers_neuronx.llama.model",
"LlamaForSampling", "LlamaForCausalLM"),
"MistralForCausalLM": ("transformers_neuronx.mistral.model",
"MistralForSampling", "MistralForCausalLM")
}
class NeuronCausalLM(nn.Module):
def __init__(self,
config: PretrainedConfig,
on_device_sampling_disabled: bool = False) -> None:
super().__init__()
self.config = config
self.logits_processor = LogitsProcessor(config.vocab_size,
logits_as_input=True)
self.on_device_sampling_disabled = on_device_sampling_disabled
if self.on_device_sampling_disabled:
# Use default sampler
self.sampler = Sampler()
# Lazy initialized
self.model: nn.Module
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_block_ids: torch.Tensor,
) -> torch.Tensor:
logits = self.model(input_ids,
cache_ids=positions,
start_ids=input_block_ids)
return logits
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(None, hidden_states, sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
if self.on_device_sampling_disabled:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
# On-device sampling outputs the token ids directly.
sampled_token_ids = logits.flatten()
next_tokens = []
sample_idx = 0
for seq_group in sampling_metadata.seq_groups:
samples = []
for seq_id in seq_group.seq_ids:
token_id = sampled_token_ids[sample_idx].item()
samples.append(
SequenceOutput(parent_seq_id=seq_id,
output_token=token_id,
logprobs={token_id: Logprob(token_id)}))
sample_idx += 1
next_tokens.append(
CompletionSequenceGroupOutput(samples=samples,
prompt_logprobs=None))
return SamplerOutput(outputs=next_tokens)
def load_weights(self, model_name_or_path: str, **kwargs):
arch = _get_model_architecture(self.config)
neuronx_module_path, neuronx_model_cls_name, hf_model_cls_name = (
_NEURON_SUPPORTED_MODELS[arch])
neuronx_module = importlib.import_module(neuronx_module_path)
neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)
self.model = neuronx_model_cls.from_pretrained(model_name_or_path,
**kwargs)
self.model.to_neuron()
class NeuronSpeculationCausalLM(nn.Module):
"""A Neuron-optimized causal language model with speculative decoding."""
SPECULATION_TERMINATION_ID = -1
def __init__(self, speculation_model) -> None:
super().__init__()
self.model = speculation_model
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_block_ids: torch.Tensor,
) -> torch.Tensor:
tokens, counts = self.model.speculative_iteration(
input_ids, positions, input_block_ids)
# Mark the end of accepted speculative tokens for each sequence with the
# speculation termination id.
batch_size, steps = tokens.shape
mask = torch.arange(steps).expand(batch_size, -1) >= counts
tokens[mask] = self.SPECULATION_TERMINATION_ID
return tokens
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[list[SamplerOutput]]:
batch_size, num_steps = logits.shape
seq_ids = [
seq_id for sg in sampling_metadata.seq_groups
for seq_id in sg.seq_ids
]
# Organize input tensors by step instead of by sequence.
accepted_token_ids_by_step = logits.transpose(0, 1)
accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
sampler_output_list = []
for step_index in range(num_steps):
if all(token_id == self.SPECULATION_TERMINATION_ID
for token_id in accepted_token_ids_by_step[step_index]):
break
step_output_token_ids = []
for sequence_index in range(batch_size):
token_id = accepted_token_ids_by_step[step_index][
sequence_index]
step_output_token_ids.append(
CompletionSequenceGroupOutput(samples=[
SequenceOutput(parent_seq_id=seq_ids[sequence_index],
output_token=token_id,
logprobs={token_id: Logprob(token_id)})
],
prompt_logprobs=None))
sampler_output_list.append(
SamplerOutput(outputs=step_output_token_ids))
return sampler_output_list
def _get_model_architecture(config: PretrainedConfig) -> str:
architectures = getattr(config, "architectures", [])
for arch in architectures:
if arch in _NEURON_SUPPORTED_MODELS:
return arch
raise ValueError(
f"Model architectures {architectures} are not supported on Neuron "
f"for now. Supported architectures: "
f"{list(_NEURON_SUPPORTED_MODELS.keys())}")
def _get_buckets(env: str, default_value: list[int]) -> list[int]:
env_value = os.getenv(env)
if env_value is None:
return default_value
buckets_remove_empty = filter(
lambda x: x is not None and len(x.strip()) > 0, env_value.split(","))
buckets_int = map(int, buckets_remove_empty)
buckets_list = list(buckets_int)
return buckets_list
def _get_default_neuron_config(model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig):
"""Generate a neuron config based on vllm config args."""
from transformers_neuronx.config import ContinuousBatchingConfig
from transformers_neuronx.constants import LAYOUT_BSH
continuous_batching_config = ContinuousBatchingConfig(
batch_size_for_shared_caches=scheduler_config.max_num_seqs)
quant_config = dict(
dequant_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
quantize_method="vector_dynamic")
neuron_quantization_config_builder = lambda quant: get_quantization_config(
quant).from_config(quant_config).get_quant_method(None, "")
# TODO: Add Paged attention config to the default neuron arguments.
default_neuron_args = dict(
collectives_layout=LAYOUT_BSH,
attention_layout=LAYOUT_BSH,
fuse_qkv=True,
quant=neuron_quantization_config_builder(model_config.quantization)
if model_config.quantization else None,
continuous_batching=continuous_batching_config,
weight_tiling=bool(model_config.quantization),
on_device_generation=_get_neuron_on_device_generation_config(
model_config))
return default_neuron_args
def _get_default_neuron_config_for_speculation(
model_config: ModelConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig):
"""Generate a neuron config for speculative decoding based on
vllm config args."""
from transformers_neuronx.config import ContinuousBatchingConfig
from transformers_neuronx.constants import LAYOUT_BSH
continuous_batching_config = ContinuousBatchingConfig(
batch_size_for_shared_caches=scheduler_config.max_num_seqs)
default_neuron_args = dict(collectives_layout=LAYOUT_BSH,
attention_layout=LAYOUT_BSH,
fuse_qkv=True,
on_device_embedding=True,
continuous_batching=continuous_batching_config,
on_device_generation=copy.deepcopy(
model_config.neuron_sampling_params))
return default_neuron_args
def _get_neuron_on_device_generation_config(model_config: ModelConfig):
if not _is_neuron_on_device_sampling_disabled(model_config):
return copy.deepcopy(model_config.neuron_sampling_params)
return None
def _is_neuron_on_device_sampling_disabled(model_config: ModelConfig) -> bool:
return not getattr(model_config, "neuron_sampling_params", None)
def _get_neuron_config_after_override(default_neuron_config,
overridden_neuron_config):
from transformers_neuronx.config import (ContinuousBatchingConfig,
GenerationConfig,
KVCacheQuantizationConfig,
NeuronConfig, QuantizationConfig,
SparseAttnConfig)
sparse_attn = overridden_neuron_config.pop("sparse_attn", {})
if sparse_attn:
overridden_neuron_config["sparse_attn"] = SparseAttnConfig(
**sparse_attn)
kv_cache_quant = overridden_neuron_config.pop("kv_cache_quant", {})
if kv_cache_quant:
overridden_neuron_config["kv_cache_quant"] = KVCacheQuantizationConfig(
**kv_cache_quant)
continuous_batching = overridden_neuron_config.pop("continuous_batching",
{})
if continuous_batching:
overridden_neuron_config[
"continuous_batching"] = ContinuousBatchingConfig(
**continuous_batching)
quant = overridden_neuron_config.pop("quant", {})
if quant:
overridden_neuron_config["quant"] = QuantizationConfig(**quant)
on_device_generation = overridden_neuron_config.pop(
"on_device_generation", {})
if on_device_generation:
overridden_neuron_config["on_device_generation"] = GenerationConfig(
**on_device_generation)
default_neuron_config.update(overridden_neuron_config)
return NeuronConfig(**default_neuron_config)
def get_neuron_model(model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig) -> nn.Module:
"""Initializes a neuron-optimized model for inference."""
# Create a model instance.
model = NeuronCausalLM(
model_config.hf_config,
_is_neuron_on_device_sampling_disabled(model_config))
default_neuron_config_args = _get_default_neuron_config(
model_config, parallel_config, scheduler_config)
neuron_config = _get_neuron_config_after_override(
default_neuron_config_args, model_config.override_neuron_config)
context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS",
[scheduler_config.max_model_len])
n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS",
[scheduler_config.max_model_len])
model.load_weights(model_config.model,
tp_degree=parallel_config.tensor_parallel_size,
amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
neuron_config=neuron_config,
context_length_estimate=context_length_estimates,
n_positions=n_positions,
batch_size=scheduler_config.max_num_seqs)
return model.eval()
def get_neuron_speculation_model(model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
speculation_config: SpeculativeConfig):
"""Initializes a neuron-optimized speculation model for inference.
This method is only applicable for speculation with a standalone draft model
"""
from transformers_neuronx.fused_speculation import FusedSpeculativeDecoder
# For Eagle SD, we need to pass in additional parameters in neuron config.
is_eagle = getattr(speculation_config.draft_model_config.hf_config,
"is_eagle", False)
# Create target model instance.
target_model = NeuronCausalLM(model_config.hf_config)
default_neuron_config_args = _get_default_neuron_config_for_speculation(
model_config, parallel_config, scheduler_config)
if is_eagle:
default_neuron_config_args['is_eagle_target'] = True
neuron_config = _get_neuron_config_after_override(
default_neuron_config_args, model_config.override_neuron_config)
context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS",
[scheduler_config.max_model_len])
n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS",
[scheduler_config.max_model_len])
target_model.load_weights(
model_config.model,
tp_degree=parallel_config.tensor_parallel_size,
amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
neuron_config=neuron_config,
context_length_estimate=context_length_estimates,
n_positions=n_positions,
batch_size=scheduler_config.max_num_seqs)
target_model.eval()
# Create draft model instance.
draft_model = NeuronCausalLM(
speculation_config.draft_model_config.hf_config)
default_draft_neuron_config_args = (
_get_default_neuron_config_for_speculation(
speculation_config.draft_model_config, parallel_config,
scheduler_config))
if is_eagle:
default_draft_neuron_config_args['is_eagle_draft'] = True
default_draft_neuron_config_args['has_pre_attention_norm'] = False
draft_neuron_config = _get_neuron_config_after_override(
default_draft_neuron_config_args,
speculation_config.draft_model_config.override_neuron_config)
draft_model.load_weights(speculation_config.draft_model_config.model,
tp_degree=speculation_config.
draft_parallel_config.tensor_parallel_size,
amp=TORCH_DTYPE_TO_NEURON_AMP[
speculation_config.draft_model_config.dtype],
neuron_config=draft_neuron_config,
context_length_estimate=context_length_estimates,
n_positions=n_positions,
batch_size=scheduler_config.max_num_seqs)
draft_model.eval()
num_speculative_tokens = speculation_config.num_speculative_tokens
# Create speculation model instance.
speculation_model = FusedSpeculativeDecoder(draft_model.model,
target_model.model,
num_speculative_tokens)
speculation_model.to_neuron()
return NeuronSpeculationCausalLM(speculation_model)
def get_neuron_eagle_speculation_model(model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
speculation_config: SpeculativeConfig):
"""Initializes a neuron-optimized EAGLE speculation model for inference."""
from transformers_neuronx.eagle_speculation import EagleSpeculativeDecoder
# Create target model instance.
target_model = NeuronCausalLM(model_config.hf_config)
default_neuron_config_args = _get_default_neuron_config_for_speculation(
model_config, parallel_config, scheduler_config)
default_neuron_config_args['is_eagle_target'] = True
neuron_config = _get_neuron_config_after_override(
default_neuron_config_args, model_config.override_neuron_config)
context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS",
[scheduler_config.max_model_len])
n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS",
[scheduler_config.max_model_len])
target_model.load_weights(
model_config.model,
tp_degree=parallel_config.tensor_parallel_size,
amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
neuron_config=neuron_config,
context_length_estimate=context_length_estimates,
n_positions=n_positions,
batch_size=scheduler_config.max_num_seqs)
target_model.eval()
# Create draft model instance.
draft_model = NeuronCausalLM(
speculation_config.draft_model_config.hf_config)
default_draft_neuron_config_args = (
_get_default_neuron_config_for_speculation(
speculation_config.draft_model_config, parallel_config,
scheduler_config))
default_draft_neuron_config_args['is_eagle_draft'] = True
default_draft_neuron_config_args['has_pre_attention_norm'] = False
draft_neuron_config = _get_neuron_config_after_override(
default_draft_neuron_config_args,
speculation_config.draft_model_config.override_neuron_config)
draft_model.load_weights(speculation_config.draft_model_config.model,
tp_degree=speculation_config.
draft_parallel_config.tensor_parallel_size,
amp=TORCH_DTYPE_TO_NEURON_AMP[
speculation_config.draft_model_config.dtype],
neuron_config=draft_neuron_config,
context_length_estimate=context_length_estimates,
n_positions=n_positions,
batch_size=scheduler_config.max_num_seqs)
draft_model.eval()
token_tree: dict[int, list[int]] = ast.literal_eval(
speculation_config.speculative_token_tree)
speculation_model = EagleSpeculativeDecoder(draft_model.model,
target_model.model,
token_tree=token_tree)
speculation_model.to_neuron()
return NeuronSpeculationCausalLM(speculation_model)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utilities for selecting and loading Neuron models in
neuronx-distributed-inference framework."""
# Disabling yapf because yapf and isort have conflicts for the below imports
# yapf: disable
import copy
import hashlib
import importlib
import multiprocessing
import os
import shutil
from typing import Optional
import torch
import torch.nn as nn
from neuronx_distributed_inference.models.config import (
FusedSpecNeuronConfig, OnDeviceSamplingConfig)
from neuronx_distributed_inference.models.mllama.utils import (
create_vision_mask)
from neuronx_distributed_inference.modules.lora_serving import (
LoraServingConfig)
from neuronx_distributed_inference.utils.hf_adapter import (
load_pretrained_config)
from transformers import AutoModelForCausalLM, AutoTokenizer, PretrainedConfig
from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import CompletionSequenceGroupOutput, SequenceOutput
# yapf: enable
logger = init_logger(__name__)
TORCH_DTYPE_TO_NEURON_AMP = {
"auto": "float32",
"half": "float16",
"float16": "float16",
"bfloat16": "bfloat16",
"float": "float32",
"float32": "float32",
torch.float16: "float16",
torch.bfloat16: "bfloat16",
torch.float32: "float32",
}
# Models supported by Neuronx distributed for inference.
_NEURON_SUPPORTED_MODELS: dict[str, tuple[str, str]] = {
"LlamaForCausalLM":
("neuronx_distributed_inference.models.llama.modeling_llama",
"NeuronLlamaForCausalLM"),
"MistralForCausalLM":
("neuronx_distributed_inference.models.llama.modeling_llama",
"NeuronLlamaForCausalLM"),
"DbrxForCausalLM":
("neuronx_distributed_inference.models.dbrx.modeling_dbrx",
"NeuronDbrxForCausalLM"),
"MixtralForCausalLM":
("neuronx_distributed_inference.models.mixtral.modeling_mixtral",
"NeuronMixtralForCausalLM"),
"MllamaForConditionalGeneration":
("neuronx_distributed_inference.models.mllama.modeling_mllama",
"NeuronMllamaForCausalLM"),
}
class NeuronCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
) -> None:
super().__init__()
self.config = config
self.logits_processor = LogitsProcessor(config.vocab_size,
logits_as_input=True)
self.sampler = Sampler()
# Lazy initialized
self.model: nn.Module
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_block_ids: torch.Tensor,
sampling_params: torch.Tensor,
prev_hidden: Optional[torch.Tensor] = None,
adapter_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
# sort block ids sequentially for perf/neuron support reasons
sorted_input_block_ids, sorted_indices = torch.sort(input_block_ids)
input_ids = torch.index_select(input_ids, 0, sorted_indices)
positions = torch.index_select(positions, 0, sorted_indices)
sampling_params = torch.index_select(sampling_params, 0,
sorted_indices)
output = self.model(input_ids,
attention_mask=None,
position_ids=positions,
seq_ids=sorted_input_block_ids,
sampling_params=sampling_params,
prev_hidden=prev_hidden,
adapter_ids=adapter_ids)
# on-device sampling
if self.config.neuron_config.on_device_sampling_config:
output = output.hidden_states
else:
output = output.logits[:, -1, :]
restored_indices = torch.argsort(sorted_indices)
if input_block_ids.shape[0] != 1:
output = torch.index_select(output, 0, restored_indices)
return output
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(None, hidden_states, sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
# on-device sampling
if self.config.neuron_config.on_device_sampling_config:
batch_size = logits.shape
seq_ids = [
seq_id for sg in sampling_metadata.seq_groups
for seq_id in sg.seq_ids
]
assert len(seq_ids) == list(batch_size)[0], "batch size mismatch"
# Organize input tensors by step instead of by sequence.
accepted_token_ids_by_step = logits.flatten()
accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
step_output_token_ids = []
for i, seq_id in enumerate(seq_ids):
token_id = accepted_token_ids_by_step[i]
step_output_token_ids.append(
CompletionSequenceGroupOutput(samples=[
SequenceOutput(parent_seq_id=seq_id,
output_token=token_id,
logprobs={token_id: Logprob(token_id)})
],
prompt_logprobs=None))
return SamplerOutput(outputs=step_output_token_ids)
else:
return self.sampler(logits, sampling_metadata)
def load_weights(self, model_name_or_path: str, **kwargs):
arch = _get_model_architecture(self.config)
neuronx_module_path, neuronx_model_cls_name = (
_NEURON_SUPPORTED_MODELS[arch])
neuronx_module = importlib.import_module(neuronx_module_path)
neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)
neuron_config = neuronx_model_cls.get_neuron_config_cls()(
**kwargs['neuron_config'])
self.config.neuron_config = neuron_config
config = neuronx_model_cls.get_config_cls()(
neuron_config,
load_config=load_pretrained_config(model_name_or_path))
hashed_config = hashlib.md5(config.to_json_string().encode('utf-8'),
usedforsecurity=False).hexdigest()
if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None:
compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS")
elif os.path.exists(model_name_or_path):
compiled_model_path = os.path.join(model_name_or_path,
"neuron-compiled-artifacts",
hashed_config)
shutil.rmtree(compiled_model_path, ignore_errors=True)
else:
compiled_model_path = os.path.join("local-models",
model_name_or_path,
"neuron-compiled-artifacts",
hashed_config)
shutil.rmtree(compiled_model_path, ignore_errors=True)
try:
self.model = neuronx_model_cls(compiled_model_path)
override_neuron_config = kwargs["override_neuron_config"]
for k, v in override_neuron_config.items():
setattr(self.model.config.neuron_config, k, v)
self.model.load(compiled_model_path)
return
except (FileNotFoundError, ValueError) as e:
logger.warning("Exception: %s", e)
logger.warning("Failed to load the model from %s, Recompiling...",
compiled_model_path)
if not os.path.exists(model_name_or_path):
hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
saved_path = os.path.join("local-models", model_name_or_path)
hf_model.save_pretrained(saved_path)
model_name_or_path = saved_path
self.model = neuronx_model_cls(model_name_or_path, config)
self.model.compile(compiled_model_path)
self.model.load(compiled_model_path)
class NeuronMllamaForCausalLM(nn.Module):
def __init__(self,
config: PretrainedConfig,
on_device_sampling_disabled: bool = False) -> None:
super().__init__()
# has_image is the only multimodal input that is used in
# token-generation
# This is a cache (on CPU) that saves has_image data per sequence id
# The number of entries in this cache is <= Batch-Size
self.has_image_cache: dict[int, torch.Tensor] = {}
self.config = config
self.logits_processor = LogitsProcessor(
config.get_text_config().vocab_size, logits_as_input=True)
self.on_device_sampling_disabled = on_device_sampling_disabled
if self.on_device_sampling_disabled:
# Use default sampler
self.sampler = Sampler()
# Lazy initialized
self.model: nn.Module
self.is_reorder_needed: bool = True
def read_from_has_image_cache(self, seq_ids: torch.Tensor):
has_image_list = []
for index in range(len(seq_ids)):
seq_id = seq_ids[index].item()
if seq_id in self.has_image_cache:
has_image_list.append(self.has_image_cache[seq_id])
else:
has_image_list.append(torch.tensor([0]))
return torch.tensor(has_image_list)
def write_to_has_image_cache(self, seq_ids: torch.Tensor,
has_image: torch.Tensor):
for index in range(len(seq_ids)):
seq_id = seq_ids[index].item()
if index < len(has_image):
self.has_image_cache[seq_id] = has_image[index]
else:
self.has_image_cache[seq_id] = torch.zeros(1)
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
seq_ids: torch.Tensor, pixel_values: torch.Tensor,
aspect_ratios: torch.Tensor, num_chunks: torch.Tensor,
has_image: torch.Tensor, sampling_params) -> torch.Tensor:
# We update the has_image cache during prefill
# and read the has_image cache during decode
if input_ids.shape[-1] > 1: # prefill
self.write_to_has_image_cache(seq_ids, has_image)
else:
has_image = self.read_from_has_image_cache(seq_ids)
bs = input_ids.shape[0]
num_chunks = torch.zeros((bs, 1))
aspect_ratios = torch.zeros((bs, 1, 2))
input_block_ids = seq_ids
origin_input_block_ids = seq_ids
if self.is_reorder_needed:
# sort block ids sequentially for perf/neuron support reasons
input_block_ids, sorted_indices = torch.sort(input_block_ids)
input_ids = torch.index_select(input_ids, 0, sorted_indices)
positions = torch.index_select(positions, 0, sorted_indices)
sampling_params = torch.index_select(sampling_params, 0,
sorted_indices)
pixel_values = torch.index_select(pixel_values, 0, sorted_indices)
aspect_ratios = torch.index_select(aspect_ratios, 0,
sorted_indices)
num_chunks = torch.index_select(num_chunks, 0, sorted_indices)
has_image = torch.index_select(has_image, 0, sorted_indices)
self.vision_mask = create_vision_mask(input_ids, self.vision_token_id)
output = self.model(
input_ids.to(torch.int32),
attention_mask=None,
position_ids=positions.to(torch.int32),
seq_ids=seq_ids.flatten().to(torch.int32),
pixel_values=pixel_values.to(
self.config.vision_config.torch_dtype),
aspect_ratios=aspect_ratios.to(torch.int32),
vision_mask=self.vision_mask.to(torch.int32),
sampling_params=sampling_params,
num_chunks=num_chunks.to(torch.int32),
has_image=has_image.to(torch.int32),
)
if self.config.neuron_config.on_device_sampling_config:
output = output.hidden_states
else:
output = output.logits[:, -1, :]
if self.is_reorder_needed and origin_input_block_ids.shape[0] != 1:
restored_indices = torch.argsort(sorted_indices)
output = torch.index_select(output, 0, restored_indices)
return output
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(None, hidden_states, sampling_metadata)
return logits
def sample(self, hidden_states, sampling_metadata):
if not self.on_device_sampling_disabled:
with torch.profiler.record_function("sample"):
hidden_states = hidden_states.flatten()
res = []
sample_idx = 0
for seq_group in sampling_metadata.seq_groups:
seq_ids = seq_group.seq_ids
samples = []
for seq_id in seq_ids:
token_id = hidden_states[sample_idx].item()
samples.append(
SequenceOutput(
parent_seq_id=seq_id,
output_token=token_id,
logprobs={token_id: Logprob(token_id)}))
sample_idx += 1
res.append(
CompletionSequenceGroupOutput(samples=samples,
prompt_logprobs=None))
next_tokens = SamplerOutput(outputs=res)
else:
next_tokens = self.sampler(None, hidden_states, sampling_metadata)
return next_tokens
def load_weights(self, model_name_or_path: str, **kwargs):
arch = _get_model_architecture(self.config)
neuronx_module_path, neuronx_model_cls_name = (
_NEURON_SUPPORTED_MODELS[arch])
neuronx_module = importlib.import_module(neuronx_module_path)
neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)
neuron_config = neuronx_model_cls.get_neuron_config_cls()(
**kwargs['neuron_config'])
self.config.neuron_config = neuron_config
logger.info("neuron_config buckets: %s",
self.config.neuron_config.buckets)
config = neuronx_model_cls.get_config_cls()(
neuron_config,
load_config=load_pretrained_config(model_name_or_path))
hashed_config = hashlib.md5(config.to_json_string().encode('utf-8'),
usedforsecurity=False).hexdigest()
if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None:
compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS")
elif os.path.exists(model_name_or_path):
compiled_model_path = os.path.join(model_name_or_path,
"neuron-compiled-artifacts",
hashed_config)
else:
compiled_model_path = os.path.join("local-models",
model_name_or_path,
"neuron-compiled-artifacts",
hashed_config)
try:
self.model = neuronx_model_cls(compiled_model_path)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
self.vision_token_id = tokenizer(
"<|image|>", add_special_tokens=False).input_ids[0]
self.model.load(compiled_model_path)
return
except (FileNotFoundError, ValueError):
logger.warning("Failed to load the model from %s, Recompiling...",
compiled_model_path)
if not os.path.exists(model_name_or_path):
hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
saved_path = os.path.join("local-models", model_name_or_path)
hf_model.save_pretrained(saved_path)
model_name_or_path = saved_path
self.model = neuronx_model_cls(model_name_or_path, config)
logger.info("\nCompiling and saving model to %s", model_name_or_path)
p = multiprocessing.Process(target=compile_model,
args=(self, compiled_model_path))
p.start()
p.join()
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
tokenizer.save_pretrained(compiled_model_path)
logger.info("Successfully compiled and saved the model in %s",
compiled_model_path)
# Read "<|image|>" token_id from the tokenizer
self.vision_token_id = tokenizer("<|image|>",
add_special_tokens=False).input_ids[0]
logger.info("\nLoading model from compiled checkpoint...")
self.model.load(compiled_model_path)
def compile_model(neuron_model, traced_model_path):
neuron_model.model.compile(traced_model_path)
class NeuronSpeculationCausalLM(nn.Module):
"""A Neuron-optimized causal language model with speculative decoding."""
def __init__(
self,
config: PretrainedConfig,
) -> None:
super().__init__()
self.config = config
self.logits_processor = LogitsProcessor(config.vocab_size,
logits_as_input=True)
# Lazy initialized
self.model: nn.Module
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_block_ids: torch.Tensor,
sampling_params: torch.Tensor,
) -> torch.Tensor:
# sort block ids sequentially for perf/neuron support reasons
sorted_input_block_ids, sorted_indices = torch.sort(input_block_ids)
input_ids = torch.index_select(input_ids, 0, sorted_indices)
positions = torch.index_select(positions, 0, sorted_indices)
sampling_params = torch.index_select(sampling_params, 0,
sorted_indices)
output = self.model(input_ids,
attention_mask=None,
position_ids=positions,
seq_ids=sorted_input_block_ids,
sampling_params=sampling_params)
restored_indices = torch.argsort(sorted_indices)
# CTX encoding
if (positions[:, 0]).sum().item() == 0:
output = output.fused_outputs[0][:, 0:1]
if input_block_ids.shape[0] != 1:
output = torch.index_select(output, 0, restored_indices)
return output
# Fused Spec (Generation)
accepted_tokens_with_padding = output.fused_outputs[0]
next_pos_ids = output.fused_outputs[-1]
generated_token_counts = next_pos_ids - positions
assert torch.any(generated_token_counts == 0).item() is False, \
"NxDI model generated no output for one or more sequences."
batch_size, steps = accepted_tokens_with_padding.shape
mask = torch.arange(steps).expand(batch_size,
-1) >= generated_token_counts
accepted_tokens_with_padding[mask] = -1
if input_block_ids.shape[0] != 1:
accepted_tokens_with_padding = torch.index_select(
accepted_tokens_with_padding, 0, restored_indices)
return accepted_tokens_with_padding
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[list[SamplerOutput]]:
batch_size, num_steps = logits.shape
seq_ids = [
seq_id for sg in sampling_metadata.seq_groups
for seq_id in sg.seq_ids
]
# Organize input tensors by step instead of by sequence.
accepted_token_ids_by_step = logits.transpose(0, 1)
accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
sampler_output_list = []
for step_index in range(num_steps):
if all(token_id == -1
for token_id in accepted_token_ids_by_step[step_index]):
break
step_output_token_ids = []
for sequence_index in range(batch_size):
token_id = accepted_token_ids_by_step[step_index][
sequence_index]
step_output_token_ids.append(
CompletionSequenceGroupOutput(samples=[
SequenceOutput(parent_seq_id=seq_ids[sequence_index],
output_token=token_id,
logprobs={token_id: Logprob(token_id)})
],
prompt_logprobs=None))
sampler_output_list.append(
SamplerOutput(outputs=step_output_token_ids))
return sampler_output_list
def load_weights(self, model_name_or_path: str,
draft_model_name_or_path: str, **kwargs):
arch = _get_model_architecture(self.config)
neuronx_module_path, neuronx_model_cls_name = (
_NEURON_SUPPORTED_MODELS[arch])
neuronx_module = importlib.import_module(neuronx_module_path)
neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)
neuron_config = neuronx_model_cls.get_neuron_config_cls()(
**kwargs['neuron_config'])
config = neuronx_model_cls.get_config_cls()(
neuron_config,
load_config=load_pretrained_config(model_name_or_path))
draft_neuron_config = copy.deepcopy(config.neuron_config)
if not config.neuron_config.enable_eagle_speculation:
draft_neuron_config.speculation_length = 0
draft_neuron_config.trace_tokengen_model = True
draft_neuron_config.enable_fused_speculation = False
if getattr(config.neuron_config, "draft_model_modules_to_not_convert",
None):
draft_neuron_config.modules_to_not_convert = (
draft_neuron_config.draft_model_modules_to_not_convert)
if config.neuron_config.enable_eagle_speculation:
draft_neuron_config.is_eagle_draft = True
draft_neuron_config.sequence_parallel_enabled = False
draft_config = neuronx_model_cls.get_config_cls()(
draft_neuron_config,
load_config=load_pretrained_config(draft_model_name_or_path))
fused_spec_config = (FusedSpecNeuronConfig(
neuronx_model_cls._model_cls,
draft_config=draft_config,
draft_model_path=draft_model_name_or_path))
config.fused_spec_config = fused_spec_config
self.config.neuron_config = neuron_config
hashed_config = hashlib.md5(config.to_json_string().encode('utf-8'),
usedforsecurity=False).hexdigest()
if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None:
compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS")
elif os.path.exists(model_name_or_path):
compiled_model_path = os.path.join(model_name_or_path,
"neuron-compiled-artifacts",
hashed_config)
shutil.rmtree(compiled_model_path, ignore_errors=True)
else:
compiled_model_path = os.path.join("local-models",
model_name_or_path,
"neuron-compiled-artifacts",
hashed_config)
shutil.rmtree(compiled_model_path, ignore_errors=True)
try:
self.model = neuronx_model_cls(compiled_model_path)
override_neuron_config = kwargs["override_neuron_config"]
for k, v in override_neuron_config.items():
setattr(self.model.config.neuron_config, k, v)
self.model.load(compiled_model_path)
return
except (FileNotFoundError, ValueError) as e:
logger.warning("Exception: %s", e)
logger.warning("Failed to load the model from %s Recompiling...",
compiled_model_path)
if not os.path.exists(model_name_or_path):
hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
saved_path = os.path.join("local-models", model_name_or_path)
hf_model.save_pretrained(saved_path)
model_name_or_path = saved_path
if not os.path.exists(draft_model_name_or_path):
if draft_model_name_or_path != model_name_or_path:
hf_model = AutoModelForCausalLM.from_pretrained(
draft_model_name_or_path)
saved_path = os.path.join("local-models",
draft_model_name_or_path)
hf_model.save_pretrained(saved_path)
draft_model_name_or_path = saved_path
else:
draft_model_name_or_path = model_name_or_path
config.fused_spec_config.draft_model_path = draft_model_name_or_path
self.model = neuronx_model_cls(model_name_or_path, config)
self.model.compile(compiled_model_path)
self.model.load(compiled_model_path)
def _get_model_architecture(config: PretrainedConfig) -> str:
architectures = getattr(config, "architectures", [])
for arch in architectures:
if arch in _NEURON_SUPPORTED_MODELS:
return arch
raise ValueError(
f"Model architectures {architectures} are not supported on Neuron "
f"for now. Supported architectures: "
f"{list(_NEURON_SUPPORTED_MODELS.keys())}")
def _get_default_neuron_config(model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
lora_serving_config: LoraServingConfig):
"""Generate a neuron config based on vllm config args."""
on_device_sampling_config = OnDeviceSamplingConfig(dynamic=True,
deterministic=False)
batch_size = scheduler_config.max_num_seqs
neuron_config = dict(
tp_degree=parallel_config.tensor_parallel_size,
ctx_batch_size=1,
batch_size=batch_size,
max_context_length=scheduler_config.max_model_len,
seq_len=scheduler_config.max_model_len,
enable_bucketing=True,
is_continuous_batching=True,
quantized=False,
torch_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
padding_side="right",
on_device_sampling_config=on_device_sampling_config,
sequence_parallel_enabled=True,
lora_serving_config=lora_serving_config)
return neuron_config
def _get_default_speculation_config(model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
speculation_config: SpeculativeConfig):
"""Generate a neuron config for speculative decoding based on vllm config
args."""
neuron_config = dict(
tp_degree=parallel_config.tensor_parallel_size,
ctx_batch_size=1,
batch_size=scheduler_config.max_num_seqs,
max_context_length=scheduler_config.max_model_len,
seq_len=scheduler_config.max_model_len,
speculation_length=speculation_config.num_speculative_tokens,
trace_tokengen_model=False,
enable_fused_speculation=True,
enable_bucketing=True,
is_continuous_batching=True,
quantized=False,
torch_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
on_device_sampling_config=dict(
top_k=1,
do_sample=False,
))
return neuron_config
def _get_neuron_config_after_override(default_neuron_config,
overridden_neuron_config):
"""Update default neuron config values with override args"""
overridden_neuron_config = overridden_neuron_config or {}
default_neuron_config.update(overridden_neuron_config)
return default_neuron_config
def get_neuron_model(model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
lora_serving_config: LoraServingConfig) -> nn.Module:
"""Initializes a neuron-optimized model for inference."""
model_arch = _get_model_architecture(model_config.hf_config)
if model_arch == "MllamaForConditionalGeneration":
model = NeuronMllamaForCausalLM(model_config.hf_config)
else:
model = NeuronCausalLM(model_config.hf_config)
default_neuron_config_args = _get_default_neuron_config(
model_config, parallel_config, scheduler_config, lora_serving_config)
neuron_config = _get_neuron_config_after_override(
default_neuron_config_args, model_config.override_neuron_config)
override_neuron_config = model_config.override_neuron_config
model.load_weights(model_config.model,
neuron_config=neuron_config,
override_neuron_config=override_neuron_config)
return model.eval()
def get_neuron_speculation_model(model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
speculation_config: SpeculativeConfig):
"""Initializes a neuron-optimized speculation model for inference.
This model handles speculation using both a draft model and an EAGLE draft.
"""
model = NeuronSpeculationCausalLM(model_config.hf_config)
default_neuron_config_args = _get_default_speculation_config(
model_config, parallel_config, scheduler_config, speculation_config)
neuron_config = _get_neuron_config_after_override(
default_neuron_config_args, model_config.override_neuron_config)
override_neuron_config = model_config.override_neuron_config
model.load_weights(model_config.model,
speculation_config.draft_model_config.model,
neuron_config=neuron_config,
override_neuron_config=override_neuron_config)
return model.eval()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment