"tests/kernels/mamba/test_causal_conv1d.py" did not exist on "78029b34ed1be46baf06f92c9e971ea1961d0867"
Commit 6d2051cc authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.3.post1' into v0.6.3.post1-dev

parents 2c7f740a a2c71c54
......@@ -15,6 +15,11 @@ CHAT_TEMPLATE = "Dummy chat template for testing {}"
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
@dataclass
class MockHFConfig:
    """Minimal stand-in for a Hugging Face model config used by the mocks."""

    # Placeholder model type; the only field this mock defines.
    model_type: str = "any"
@dataclass
class MockModelConfig:
tokenizer = MODEL_NAME
......@@ -24,6 +29,7 @@ class MockModelConfig:
tokenizer_revision = None
embedding_mode = False
multimodal_config = MultiModalConfig()
hf_config = MockHFConfig()
@dataclass
......
......@@ -104,28 +104,40 @@ async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str,
"role": "user",
"content": "Can I ask a question? vllm1"
}]
prompt = tokenizer.apply_chat_template(
add_generation_prompt=add_generation,
conversation=conversation,
tokenize=False)
tokens = tokenizer.encode(prompt, add_special_tokens=add_special)
response = requests.post(base_url + "/tokenize",
json={
"add_generation_prompt":
add_generation,
"add_special_tokens": add_special,
"messages": conversation,
"model": model_name
})
response.raise_for_status()
assert response.json() == {
"tokens": tokens,
"count": len(tokens),
"max_model_len": 8192
}
for continue_final in [False, True]:
if add_generation and continue_final:
continue
if continue_final:
conversation.append({
"role": "assistant",
"content": "Sure,"
})
prompt = tokenizer.apply_chat_template(
add_generation_prompt=add_generation,
continue_final_message=continue_final,
conversation=conversation,
tokenize=False)
tokens = tokenizer.encode(prompt,
add_special_tokens=add_special)
response = requests.post(base_url + "/tokenize",
json={
"add_generation_prompt":
add_generation,
"continue_final_message":
continue_final,
"add_special_tokens": add_special,
"messages": conversation,
"model": model_name
})
response.raise_for_status()
assert response.json() == {
"tokens": tokens,
"count": len(tokens),
"max_model_len": 8192
}
@pytest.mark.asyncio
......
......@@ -23,9 +23,16 @@ TEST_IMAGE_URLS = [
@pytest.fixture(scope="module")
def server():
args = [
"--dtype", "bfloat16", "--max-model-len", "4096", "--max-num-seqs",
"5", "--enforce-eager", "--trust-remote-code", "--limit-mm-per-prompt",
f"image={MAXIMUM_IMAGES}"
"--dtype",
"bfloat16",
"--max-model-len",
"2048",
"--max-num-seqs",
"5",
"--enforce-eager",
"--trust-remote-code",
"--limit-mm-per-prompt",
f"image={MAXIMUM_IMAGES}",
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
......
......@@ -20,22 +20,22 @@ def test_env(name: str, device: str, monkeypatch):
if device == "cpu":
with patch("vllm.attention.selector.is_cpu", return_value=True):
backend = which_attn_to_use(8, 16, 8, None, torch.float16,
torch.float16, 16)
backend = which_attn_to_use(16, None, torch.float16, torch.float16,
16, False)
assert backend.name == "TORCH_SDPA"
elif device == "hip":
with patch("vllm.attention.selector.is_hip", return_value=True):
backend = which_attn_to_use(8, 16, 8, None, torch.float16,
torch.float16, 16)
backend = which_attn_to_use(16, None, torch.float16, torch.float16,
16, False)
assert backend.name == "ROCM_FLASH"
elif device == "openvino":
with patch("vllm.attention.selector.is_openvino", return_value=True):
backend = which_attn_to_use(8, 16, 8, None, torch.float16,
torch.float16, 16)
backend = which_attn_to_use(16, None, torch.float16, torch.float16,
16, False)
assert backend.name == "OPENVINO"
else:
backend = which_attn_to_use(8, 16, 8, None, torch.float16,
torch.float16, 16)
backend = which_attn_to_use(16, None, torch.float16, torch.float16, 16,
False)
assert backend.name == name
......@@ -46,32 +46,37 @@ def test_flash_attn(monkeypatch):
# Unsupported CUDA arch
with patch("torch.cuda.get_device_capability", return_value=(7, 5)):
backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
backend = which_attn_to_use(16, None, torch.float16, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported data type
backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None, 16)
backend = which_attn_to_use(16, None, torch.float8_e4m3fn, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported kv cache data type
backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16)
backend = which_attn_to_use(16, None, torch.float16, "fp8", 16, False)
assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported block size
backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8)
backend = which_attn_to_use(16, None, torch.float16, None, 8, False)
assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported sliding window
backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16)
backend = which_attn_to_use(16, 1, torch.float16, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL
# flash-attn is not installed
with patch.dict('sys.modules', {'vllm_flash_attn': None}):
backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
backend = which_attn_to_use(16, None, torch.float16, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported head size
backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16)
backend = which_attn_to_use(17, None, torch.float16, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL
# Attention-free models should bypass env and use PlaceholderAttention
backend = which_attn_to_use(16, None, torch.float16, torch.float16, 16,
True)
assert backend.name != STR_FLASH_ATTN_VAL
......@@ -79,4 +84,4 @@ def test_invalid_env(monkeypatch):
"""Throw an exception if the backend name is invalid."""
override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
with pytest.raises(ValueError):
which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
which_attn_to_use(16, None, torch.float16, None, 16, False)
import os
import pytest
import torch
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops # noqa: F401
@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_dequantize"),
reason="AWQ is not supported on this GPU type.")
def test_awq_dequantize_opcheck():
os.environ["VLLM_USE_TRITON_AWQ"] = "0"
qweight = torch.randint(-2000000000,
......@@ -21,6 +24,8 @@ def test_awq_dequantize_opcheck():
(qweight, scales, zeros, split_k_iters, thx, thy))
@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_gemm"),
reason="AWQ is not supported on this GPU type.")
def test_awq_gemm_opcheck():
os.environ["VLLM_USE_TRITON_AWQ"] = "0"
input = torch.rand((2, 8192), device='cuda', dtype=torch.float16)
......
"""Test AWQ with fused MoE Marlin kernels.
Run `pytest tests/kernels/test_awq_marlin.py`.
"""
import pytest
import torch
from tests.kernels.utils import (compute_max_diff, stack_and_dev, torch_moe,
torch_moe_single)
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe, single_marlin_moe)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
awq_marlin_quantize)
from vllm.scalar_type import scalar_types
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.skipif(not (ops.supports_moe_ops
                         and hasattr(torch.ops._moe_C, "marlin_gemm_moe")),
                    reason="Marlin is not supported on this GPU type.")
def test_fused_marlin_moe_awq(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    group_size: int,
):
    """Compare ``fused_marlin_moe`` on AWQ-quantized experts with ``torch_moe``.

    Random fp16 activations of shape ``(m, k)`` and two stacks of ``e`` expert
    weights are generated; every expert is passed through
    ``awq_marlin_quantize`` and the Marlin result is required to stay within
    ``4e-2`` (as measured by ``compute_max_diff``) of the result computed from
    the reference weights that the quantizer returned.
    """
    torch.manual_seed(7)

    num_bits = 4
    quant_type = scalar_types.uint4
    dtype = torch.float16

    def _quantize_experts(expert_weights):
        # Quantize expert by expert (in index order), then regroup the
        # per-expert 4-tuples into four stacked tensors.
        per_expert = [
            awq_marlin_quantize(expert.transpose(1, 0), quant_type,
                                group_size) for expert in expert_weights
        ]
        refs, packed, group_scales, zero_points = (
            list(column) for column in zip(*per_expert))
        return (
            stack_and_dev(refs),
            stack_and_dev(packed).contiguous(),
            stack_and_dev(group_scales),
            stack_and_dev(zero_points),
        )

    # Keep the draw order a -> w1 -> w2 -> score so the seeded values match.
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10

    w_ref1, qweight1, scales1, zp1 = _quantize_experts(w1)
    w_ref2, qweight2, scales2, zp2 = _quantize_experts(w2)

    score = torch.randn((m, e), device="cuda", dtype=dtype)
    topk_weights, topk_ids = fused_topk(a, score, topk, False)

    marlin_output = fused_marlin_moe(
        a,
        qweight1,
        qweight2,
        scales1,
        scales2,
        score,
        topk_weights,
        topk_ids,
        w1_zeros=zp1,
        w2_zeros=zp2,
        num_bits=num_bits,
    )

    torch_output = torch_moe(
        a,
        w_ref1.transpose(1, 2),
        w_ref2.transpose(1, 2),
        score,
        topk,
    )

    max_diff = compute_max_diff(marlin_output, torch_output)
    assert max_diff < 4e-2
@pytest.mark.skip("This test is here for the sake of debugging, "
                  "don't run it in automated tests.")
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
def test_single_marlin_moe_multiply_awq(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    group_size: int,
):
    """Compare ``single_marlin_moe`` on AWQ-quantized experts with a reference.

    A single stack of ``e`` expert weights is quantized expert by expert with
    ``awq_marlin_quantize``; the Marlin output must stay within ``1e-2`` (per
    ``compute_max_diff``) of ``torch_moe_single`` evaluated on the reference
    weights returned by the quantizer. Skipped unconditionally by the
    decorator above.
    """
    torch.manual_seed(7)

    num_bits = 4
    quant_type = scalar_types.uint4
    dtype = torch.float16

    # Draw order a -> w -> score is kept so the seeded values are unchanged.
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10

    # One (w_ref, qweight, scales, zp) tuple per expert, in expert order.
    per_expert = [
        awq_marlin_quantize(expert.transpose(1, 0), quant_type, group_size)
        for expert in w
    ]
    refs, packed, group_scales, zero_points = (
        list(column) for column in zip(*per_expert))

    w_ref = stack_and_dev(refs)
    qweight = stack_and_dev(packed).contiguous()
    scales = stack_and_dev(group_scales).contiguous()
    zp = stack_and_dev(zero_points).contiguous()

    score = torch.randn((m, e), device="cuda", dtype=dtype)

    marlin_output = single_marlin_moe(
        a,
        qweight,
        scales,
        score,
        topk,
        renormalize=False,
        w_zeros=zp,
        num_bits=num_bits,
    )
    torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)

    max_diff = compute_max_diff(marlin_output, torch_output)
    assert max_diff < 1e-2
......@@ -3,10 +3,10 @@ from typing import Optional
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops # noqa: F401
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.utils import seed_everything
......@@ -57,45 +57,73 @@ def causal_conv1d_ref(
return (out, None) if not return_final_states else (out, final_states_out)
def causal_conv1d_update_ref(x: torch.Tensor,
conv_state: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
activation: Optional[str] = None):
def causal_conv1d_update_ref(x,
conv_state,
weight,
bias=None,
activation=None,
cache_seqlens=None):
"""
x: (batch, dim)
conv_state: (batch, dim, width)
x: (batch, dim) or (batch, dim, seqlen)
conv_state: (batch, dim, state_len), where state_len >= width - 1
weight: (dim, width)
bias: (dim,)
cache_seqlens: (batch,), dtype int32.
If not None, the conv_state is treated as a circular buffer.
The conv_state will be updated by copying x to the
conv_state starting at the index
@cache_seqlens % state_len before performing the convolution.
out: (batch, dim)
out: (batch, dim) or (batch, dim, seqlen)
"""
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
dtype_in = x.dtype
batch, dim = x.shape
unsqueeze = x.dim() == 2
if unsqueeze:
x = x.unsqueeze(-1)
batch, dim, seqlen = x.shape
width = weight.shape[1]
assert conv_state.shape == (batch, dim, width)
state_len = conv_state.shape[-1]
assert conv_state.shape == (batch, dim, state_len)
assert weight.shape == (dim, width)
conv_state.copy_(torch.roll(conv_state, shifts=-1,
dims=-1)) # Update state (B D W)
conv_state[:, :, -1] = x
out = torch.sum(conv_state * weight, dim=-1) # (B D)
if bias is not None:
out += bias
if cache_seqlens is None:
x_new = torch.cat([conv_state, x], dim=-1).to(
weight.dtype) # (batch, dim, state_len + seqlen)
conv_state.copy_(x_new[:, :, -state_len:])
else:
width_idx = torch.arange(
-(width - 1), 0, dtype=torch.long,
device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(
-1, dim, -1)
x_new = torch.cat([conv_state.gather(2, width_idx), x],
dim=-1).to(weight.dtype)
copy_idx = torch.arange(
seqlen, dtype=torch.long,
device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
copy_idx = torch.remainder(copy_idx,
state_len).unsqueeze(1).expand(-1, dim, -1)
conv_state.scatter_(2, copy_idx, x)
out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0,
groups=dim)[:, :, -seqlen:]
if unsqueeze:
out = out.squeeze(-1)
return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
def causal_conv1d_opcheck_fn(
x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
seq_idx: Optional[torch.Tensor] = None,
initial_states: Optional[torch.Tensor] = None,
return_final_states: bool = False,
final_states_out=None,
activation: Optional[str] = "silu",
):
@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
def causal_conv1d_opcheck_fn(x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
cu_seq_len: Optional[torch.Tensor] = None,
cache_indices: Optional[torch.Tensor] = None,
has_initial_state: Optional[torch.Tensor] = None,
conv_states: Optional[torch.Tensor] = None,
activation: Optional[str] = "silu",
pad_slot_id: int = PAD_SLOT_ID):
"""
x: (batch, dim, seqlen)
weight: (dim, width)
......@@ -109,135 +137,86 @@ def causal_conv1d_opcheck_fn(
"""
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
if x.stride(2) != 1 and x.stride(1) != 1:
if x.stride(-1) != 1:
x = x.contiguous()
bias = bias.contiguous() if bias is not None else None
if seq_idx is not None:
assert (initial_states is
None), "initial_states must be None if seq_idx is not None"
assert (not return_final_states
), "If seq_idx is not None, we don't return final_states_out"
seq_idx = seq_idx.contiguous() if seq_idx is not None else None
if initial_states is not None and (initial_states.stride(2) != 1
and initial_states.stride(1) != 1):
initial_states = initial_states.contiguous()
if return_final_states:
assert (
x.stride(1) == 1
), "Only channel-last layout support returning final_states_out"
if final_states_out is not None:
assert (final_states_out.stride(2) == 1
or final_states_out.stride(1) == 1)
else:
batch, dim, seqlen = x.shape
width = weight.shape[1]
final_states_out = torch.empty(batch,
width - 1,
dim,
device=x.device,
dtype=x.dtype).transpose(1, 2)
else:
final_states_out = None
opcheck(torch.ops._C.causal_conv1d_fwd,
(x, weight, bias, seq_idx, initial_states, final_states_out,
activation in ["silu", "swish"]))
(x, weight, bias, conv_states, cu_seq_len, cache_indices,
has_initial_state, activation in ["silu", "swish"], pad_slot_id))
@pytest.mark.parametrize("return_final_states", [False, True])
@pytest.mark.parametrize("has_initial_states", [False, True])
@pytest.mark.parametrize("channel_last", [False, True])
@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize("seqlen", [128, 512, 4096])
@pytest.mark.parametrize('dim', [64, 4096 + 32])
@pytest.mark.parametrize('batch', [1, 2])
@pytest.mark.parametrize(
'seqlen', [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
@pytest.mark.parametrize('dim', [64])
@pytest.mark.parametrize('batch', [1])
def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
itype, channel_last, has_initial_states,
return_final_states):
if not channel_last and (has_initial_states or return_final_states):
pytest.skip(
"Only channel_last support initial_states or return_final_states")
itype):
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
# set seed
seed_everything(0)
if not channel_last:
x = torch.randn(batch,
4096 + dim + 64,
seqlen,
device=device,
dtype=itype)[:, 4096:4096 + dim, :]
else:
x = rearrange(
torch.randn(batch,
seqlen,
4096 + dim + 64,
device=device,
dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s")
x = torch.randn(batch, dim, seqlen, device=device,
dtype=itype).contiguous()
weight = torch.randn(dim, width, device=device, dtype=itype)
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
if has_initial_states:
initial_states = torch.randn(batch,
width - 1,
dim,
device=device,
dtype=itype).transpose(1, 2)
else:
initial_states = None
x_ref = x.detach().clone()
weight_ref = weight.detach().clone()
bias_ref = bias.detach().clone() if bias is not None else None
initial_states_ref = initial_states.detach().clone(
initial_states = torch.randn(batch,
dim,
width - 1,
device=device,
dtype=itype)
x_ref = x.clone()
weight_ref = weight.clone()
bias_ref = bias.clone() if bias is not None else None
initial_states_ref = initial_states.clone(
) if initial_states is not None else None
activation = None if not silu_activation else "silu"
out, final_states = causal_conv1d_fn(
x,
weight,
bias,
initial_states=initial_states,
return_final_states=return_final_states,
activation=activation)
out = causal_conv1d_fn(x,
weight,
bias,
activation=activation,
conv_states=initial_states,
has_initial_state=torch.ones(batch,
dtype=torch.bool,
device=x.device))
out_ref, final_states_ref = causal_conv1d_ref(
x_ref,
weight_ref,
bias_ref,
initial_states=initial_states_ref,
return_final_states=return_final_states,
return_final_states=True,
activation=activation)
causal_conv1d_opcheck_fn(x_ref,
weight_ref,
bias_ref,
initial_states=initial_states_ref,
return_final_states=return_final_states,
activation=activation)
if return_final_states:
assert final_states is not None and final_states_ref is not None
assert torch.allclose(final_states,
final_states_ref,
rtol=rtol,
atol=atol)
assert initial_states is not None and final_states_ref is not None
assert torch.allclose(initial_states,
final_states_ref,
rtol=rtol,
atol=atol)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
if return_final_states:
out += F.sigmoid(final_states).sum(dim=-1, keepdim=True)
out_ref += F.sigmoid(final_states_ref).sum(dim=-1, keepdim=True)
causal_conv1d_opcheck_fn(x,
weight,
bias,
activation=activation,
conv_states=initial_states,
has_initial_state=torch.ones(batch,
dtype=torch.bool,
device=x.device))
@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("width", [2, 3, 4])
@pytest.mark.parametrize("seqlen", [1])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
@pytest.mark.parametrize("batch", [1, 2])
def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation,
def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation,
itype):
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
......@@ -246,17 +225,12 @@ def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation,
# set seed
seed_everything(0)
batch = 2
x = torch.randn(batch, dim, device=device, dtype=itype)
conv_state = torch.randn(batch, dim, width, device=device, dtype=itype)
weight = torch.randn(dim,
width,
device=device,
dtype=itype,
requires_grad=True)
if has_bias:
bias = torch.randn(dim, device=device, dtype=itype, requires_grad=True)
else:
bias = None
x = torch.randn(batch, dim, seqlen, device=device, dtype=itype)
x_ref = x.clone()
conv_state = torch.randn(batch, dim, width - 1, device=device, dtype=itype)
weight = torch.randn(dim, width, device=device, dtype=itype)
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
conv_state_ref = conv_state.detach().clone()
activation = None if not silu_activation else "silu"
out = causal_conv1d_update(x,
......@@ -264,7 +238,7 @@ def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation,
weight,
bias,
activation=activation)
out_ref = causal_conv1d_update_ref(x,
out_ref = causal_conv1d_update_ref(x_ref,
conv_state_ref,
weight,
bias,
......@@ -273,9 +247,9 @@ def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation,
assert torch.equal(conv_state, conv_state_ref)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
opcheck(
torch.ops._C.causal_conv1d_update,
(x, conv_state, weight, bias, activation in ["silu", "swish"], None))
opcheck(torch.ops._C.causal_conv1d_update,
(x, conv_state, weight, bias, activation
in ["silu", "swish"], None, None, PAD_SLOT_ID))
@pytest.mark.parametrize("itype",
......@@ -285,7 +259,10 @@ def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation,
@pytest.mark.parametrize("seqlen", [1, 4, 5])
@pytest.mark.parametrize("width", [2, 3, 4])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_bias,
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [True, False])
def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width,
seqlen, has_bias,
silu_activation, itype):
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
......@@ -293,29 +270,37 @@ def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_bias,
rtol, atol = 1e-2, 5e-2
# set seed
torch.random.manual_seed(0)
batch = 64
seed_everything(0)
x = torch.randn(batch, dim, device=device, dtype=itype)
batch_size = 3
padding = 5 if with_padding else 0
padded_batch_size = batch_size + padding
total_entries = 10 * batch_size
total_entries = 10 * batch
x = torch.randn(padded_batch_size, dim, 1, device=device, dtype=itype)
x_ref = x.clone()
conv_state_indices = torch.randperm(total_entries)[:batch_size].to(
dtype=torch.int32, device=device)
unused_states_bool = torch.ones(total_entries,
dtype=torch.bool,
device=device)
unused_states_bool[conv_state_indices] = False
padded_state_indices = torch.concat([
conv_state_indices,
torch.as_tensor(
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device)
],
dim=0)
conv_state = torch.randn(total_entries,
dim,
width,
width - 1,
device=device,
dtype=itype)
conv_state_indices = torch.randperm(total_entries)[:batch].to(
dtype=torch.int32, device=device)
conv_state_for_padding_test = conv_state.clone()
weight = torch.randn(dim,
width,
device=device,
dtype=itype,
requires_grad=True)
if has_bias:
bias = torch.randn(dim, device=device, dtype=itype, requires_grad=True)
else:
bias = None
weight = torch.randn(dim, width, device=device, dtype=itype)
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
activation = None if not silu_activation else "silu"
out = causal_conv1d_update(x,
......@@ -323,12 +308,120 @@ def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_bias,
weight,
bias,
activation=activation,
conv_state_indices=conv_state_indices)
out_ref = causal_conv1d_update_ref(x,
conv_state_indices=padded_state_indices,
pad_slot_id=PAD_SLOT_ID)
out_ref = causal_conv1d_update_ref(x_ref[:batch_size],
conv_state_ref,
weight,
bias,
activation=activation)
assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
assert torch.equal(conv_state[unused_states_bool],
conv_state_for_padding_test[unused_states_bool])
opcheck(torch.ops._C.causal_conv1d_update,
(x, conv_state, weight, bias, activation
in ["silu", "swish"], None, padded_state_indices, PAD_SLOT_ID))
@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize(
    'seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 2049, 4096])
@pytest.mark.parametrize('dim', [64, 4096])
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize('with_padding', [True, False])
def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
                              silu_activation, itype):
    """Check ``causal_conv1d_fn`` on a packed variable-length batch.

    ``seqlen`` tokens are cut at random positions into ``padded_batch_size``
    non-empty sequences. The kernel output and the conv-state cache it
    updates are compared with per-sequence ``causal_conv1d_ref`` calls;
    sequences whose cache slot is ``PAD_SLOT_ID`` are left out of the
    reference. The same arguments are finally run through
    ``causal_conv1d_opcheck_fn``.
    """
    device = "cuda"
    torch.cuda.empty_cache()

    # bfloat16 gets the loosest tolerances, float32 the tightest.
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
    elif itype == torch.float32:
        rtol, atol = 3e-4, 1e-3
    else:
        rtol, atol = 3e-3, 5e-3

    # set seed
    seed_everything(0)

    batch_size = 1 if seqlen < 10 else 4
    padding = 3 if with_padding else 0
    padded_batch_size = batch_size + padding
    nsplits = padded_batch_size - 1

    # Pick sorted end-of-sequence positions; consecutive differences (with
    # sentinels -1 and seqlen - 1) are the individual sequence lengths.
    eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
    boundaries = torch.cat(
        [torch.tensor([-1]), eos_pos,
         torch.tensor([seqlen - 1])])
    seq_lengths = torch.diff(boundaries).tolist()
    assert sum(seq_lengths) == seqlen
    assert all(s > 0 for s in seq_lengths)

    total_entries = batch_size * 10

    # Cumulative sequence starts, with a leading zero, as int32.
    cumsum = torch.cumsum(torch.tensor(seq_lengths), dim=0).to(torch.int32)
    cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum],
                          dim=0)

    # x is a channel slice of a wider random tensor.
    wide_x = torch.randn(1,
                         4096 + dim + 64,
                         seqlen,
                         device=device,
                         dtype=itype)
    x = wide_x[:, 4096:4096 + dim, :]
    weight = torch.randn(dim, width, device=device, dtype=itype)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None

    x_ref = x.clone()
    weight_ref = weight.clone()
    bias_ref = None if bias is None else bias.clone()
    activation = "silu" if silu_activation else None

    final_states = torch.randn(total_entries,
                               dim,
                               width - 1,
                               device=x.device,
                               dtype=x.dtype)
    final_states_ref = final_states.clone()

    has_initial_states = torch.randint(0,
                                       2, (cumsum.shape[0] - 1, ),
                                       dtype=torch.bool,
                                       device=x.device)
    state_indices = torch.randperm(total_entries,
                                   dtype=torch.int32,
                                   device=x.device)[:batch_size]
    pad_slots = torch.as_tensor([PAD_SLOT_ID] * padding,
                                dtype=torch.int32,
                                device=device)
    padded_state_indices = torch.concat([state_indices, pad_slots], dim=-1)

    out = causal_conv1d_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
                           padded_state_indices, has_initial_states,
                           final_states, activation, PAD_SLOT_ID)

    # Reference: run each non-padded sequence separately; the reference
    # writes its final state into the matching slot of final_states_ref.
    segments = torch.split(x_ref[0], seq_lengths, dim=-1)
    ref_results = []
    for i, segment in enumerate(segments):
        x_s = segment.unsqueeze(0)
        slot = padded_state_indices[i]
        if slot == PAD_SLOT_ID:
            continue
        state_out = final_states_ref[slot].unsqueeze(0)
        state_in = (final_states_ref[slot].unsqueeze(0)
                    if has_initial_states[i] else None)
        ref_results.append(
            causal_conv1d_ref(x_s,
                              weight_ref,
                              bias_ref,
                              activation=activation,
                              return_final_states=True,
                              final_states_out=state_out,
                              initial_states=state_in))

    # Concatenate the per-sequence outputs back along the sequence axis.
    out_ref_tensor = torch.cat([result[0] for result in ref_results], dim=2)

    unpadded_out = out[:, :out_ref_tensor.shape[-1]]
    assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)
    assert torch.allclose(final_states, final_states_ref, rtol=rtol, atol=atol)

    causal_conv1d_opcheck_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
                             padded_state_indices, has_initial_states,
                             final_states, activation)
......@@ -136,7 +136,9 @@ def _make_test_resources(test_pt: TestPoint, ) -> TestResources:
)
if test_pt.num_blocks is None or test_pt.num_heads is None:
# Caller does not require a KV cache
return TestResources(scale, attn_backend, attn, None)
return TestResources(
scale, attn_backend, attn,
torch.tensor([], dtype=torch.float32, device=CUDA_DEVICE))
# Construct KV cache
kv_cache = make_kv_cache(test_pt.num_blocks,
......@@ -620,7 +622,9 @@ def _run_encoder_attention_test(
return attn.forward(packed_qkv.query,
packed_qkv.key,
packed_qkv.value,
None,
torch.tensor([],
dtype=torch.float32,
device=packed_qkv.query.device),
attn_metadata,
attn_type=attn_type)
......
......@@ -3,9 +3,9 @@ from typing import List, Optional, Tuple
import pytest
import torch
import vllm.attention.backends.flash_attn # noqa: F401
from tests.kernels.utils import opcheck
from vllm.utils import seed_everything
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
flash_attn_with_kvcache)
NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
HEAD_SIZES = [128, 256]
......@@ -112,10 +112,10 @@ def test_flash_attn_with_paged_kv(
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)
output = torch.ops.vllm.flash_attn_with_kvcache(
decode_query=query.unsqueeze(1),
key_cache=key_cache,
value_cache=value_cache,
output = flash_attn_with_kvcache(
q=query.unsqueeze(1),
k_cache=key_cache,
v_cache=value_cache,
softmax_scale=scale,
causal=True,
block_table=block_tables,
......@@ -123,25 +123,6 @@ def test_flash_attn_with_paged_kv(
softcap=soft_cap if soft_cap is not None else 0,
).squeeze(1)
if num_blocks <= 2048:
test_utils = ["test_faketensor", "test_schema"]
else:
test_utils = ["test_faketensor"]
opcheck(torch.ops.vllm.flash_attn_with_kvcache,
args=tuple(),
kwargs=dict(
decode_query=query.unsqueeze(1),
key_cache=key_cache,
value_cache=value_cache,
softmax_scale=scale,
causal=True,
block_table=block_tables,
cache_seqlens=kv_lens_tensor,
softcap=soft_cap if soft_cap is not None else 0,
),
test_utils=test_utils)
ref_output = ref_paged_attn(
query=query,
key_cache=key_cache,
......@@ -213,7 +194,7 @@ def test_varlen_with_paged_kv(
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)
output = torch.ops.vllm.flash_attn_varlen_func(
output = flash_attn_varlen_func(
q=query,
k=key_cache,
v=value_cache,
......@@ -228,29 +209,6 @@ def test_varlen_with_paged_kv(
softcap=soft_cap if soft_cap is not None else 0,
)
if num_blocks <= 2048:
test_utils = ["test_faketensor", "test_schema"]
else:
test_utils = ["test_faketensor"]
opcheck(torch.ops.vllm.flash_attn_varlen_func,
args=tuple(),
kwargs=dict(
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=cu_query_lens,
cu_seqlens_k=cu_kv_lens,
max_seqlen_q=max_query_len,
max_seqlen_k=max_kv_len,
softmax_scale=scale,
causal=True,
window_size=window_size,
block_table=block_tables,
softcap=soft_cap if soft_cap is not None else 0,
),
test_utils=test_utils)
ref_output = ref_paged_attn(
query=query,
key_cache=key_cache,
......
......@@ -24,13 +24,14 @@ MNK_SHAPES = [
(1, 128, 128),
(1, 512, 1024),
(1, 4096, 4096),
(1, 8192, 28672),
(13, 8192, 4096),
(26, 4096, 8192),
(1, 4096, 4096),
(64, 4096, 4096),
(64, 8192, 28672),
(257, 128, 4096),
(257, 4224, 4160),
(257, 4096, 4096),
(64, 4096, 4096),
(1024, 4096, 8192),
(1024, 8192, 4096),
]
......
......@@ -5,6 +5,7 @@ from einops import rearrange, repeat
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops # noqa: F401
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_scan_fn, selective_state_update)
from vllm.utils import seed_everything
......@@ -98,8 +99,8 @@ def selective_scan_ref(u,
delta_bias=None,
delta_softplus=False,
return_last_state=False,
position_indices=None,
prev_state=None):
prev_state=None,
final_state_out=None):
"""
u: r(B D L)
delta: r(B D L)
......@@ -139,12 +140,8 @@ def selective_scan_ref(u,
deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
if is_variable_C and C.dim() == 4:
C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
last_state = None
for i in range(u.shape[2]):
if position_indices is not None and position_indices[0, i] == 0:
x = deltaB_u[:, :, i]
else:
x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
if not is_variable_C:
y = torch.einsum('bdn,dn->bd', x, C)
else:
......@@ -153,14 +150,17 @@ def selective_scan_ref(u,
else:
y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
if i == u.shape[2] - 1:
last_state = x
if final_state_out is None:
final_state_out = x
else:
final_state_out.copy_(x)
ys.append(y)
y = torch.stack(ys, dim=2) # (batch dim L)
out = y if D is None else y + u * rearrange(D, "d -> d 1")
if z is not None:
out = out * F.silu(z)
out = out.to(dtype=dtype_in)
return out if not return_last_state else (out, last_state)
return out if not return_last_state else (out, final_state_out)
def selective_scan_opcheck_fn(u,
......@@ -172,9 +172,11 @@ def selective_scan_opcheck_fn(u,
z=None,
delta_bias=None,
delta_softplus=False,
return_last_state=False,
position_indices=None,
prev_state=None):
cu_seq_len=None,
cache_indices=None,
has_initial_state=None,
ssm_states=None,
pad_slot_id=PAD_SLOT_ID):
"""if return_last_state is True, returns (out, last_state)
last_state has shape (batch, dim, dstate).
"""
......@@ -190,36 +192,27 @@ def selective_scan_opcheck_fn(u,
C = C.contiguous()
if z is not None and z.stride(-1) != 1:
z = z.contiguous()
if B.dim() == 3:
if B.dim() == 3 and cu_seq_len is None:
B = B.unsqueeze(1)
if C.dim() == 3:
if B.dim() == 2 and cu_seq_len is not None:
B = B.unsqueeze(0)
if C.dim() == 3 and cu_seq_len is None:
C = C.unsqueeze(1)
n_chunks = int((u.shape[-1] + 2048 - 1) / 2048)
x = torch.zeros((
u.shape[0],
u.shape[1],
n_chunks,
int(A.shape[1] * 2),
),
device=u.device,
dtype=torch.float32,
requires_grad=False)
x[:, :, 0, 0::2] = 1
if prev_state is not None:
x[:, :, 0, 1::2].copy_(prev_state)
if C.dim() == 2 and cu_seq_len is not None:
C = C.unsqueeze(0)
# Disable test_autograd_registration for now as it seems to trigger
# a bogus error.
opcheck(torch.ops._C.selective_scan_fwd,
(u, delta, A, B, C, D, z, delta_bias, delta_softplus,
position_indices, x),
(u, delta, A, B, C, D, z, delta_bias, delta_softplus, cu_seq_len,
cache_indices, has_initial_state, ssm_states, pad_slot_id),
test_utils=["test_schema", "test_faketensor"])
@pytest.mark.parametrize('wtype', [torch.float32])
@pytest.mark.parametrize('itype', [torch.float32])
@pytest.mark.parametrize('itype',
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize('seqlen', [128, 256, 512, 1024, 2048, 4096])
@pytest.mark.parametrize("return_last_state", [True])
@pytest.mark.parametrize('has_delta_bias', [True])
@pytest.mark.parametrize('delta_softplus', [True])
@pytest.mark.parametrize('has_z', [True])
......@@ -229,8 +222,8 @@ def selective_scan_opcheck_fn(u,
@pytest.mark.parametrize("is_variable_B", [True])
@pytest.mark.parametrize("scan_chunks", [1, 2, 3])
def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
has_z, has_delta_bias, delta_softplus,
return_last_state, seqlen, itype, wtype, scan_chunks):
has_z, has_delta_bias, delta_softplus, seqlen, itype,
wtype, scan_chunks):
if varBC_groups > 1 and (not is_variable_B or not is_variable_C):
pytest.skip() # This config is not applicable
device = 'cuda'
......@@ -243,10 +236,11 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
atolw = max(atolw, atol)
# set seed
seed_everything(0)
batch_size = 2
batch_size = 1
dim = 4
dstate = 8
A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype))
A_ref = A.clone()
if not is_variable_B:
B_shape = [dim, dstate]
elif varBC_groups == 1:
......@@ -256,6 +250,7 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
B = torch.randn(B_shape,
device=device,
dtype=wtype if not is_variable_B else itype)
B_ref = B.clone()
if not is_variable_C:
C_shape = [dim, dstate]
elif varBC_groups == 1:
......@@ -265,16 +260,25 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
C = torch.randn(C_shape,
device=device,
dtype=wtype if not is_variable_C else itype)
C_ref = C.clone()
D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None
D_ref = D.clone()
z = torch.randn(batch_size, dim, seqlen, device=device,
dtype=itype) if has_z else None
z_ref = z.clone() if has_z else None
delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)
) if has_delta_bias else None
u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype)
u_ref = u.clone()
delta = (0.5 *
torch.rand(batch_size, dim, seqlen, device=device, dtype=itype))
state = None
state_ref = None
delta_ref = delta.clone()
state_shape = (batch_size, u.shape[1], int(A.shape[1]))
state = torch.randn(state_shape,
device=u.device,
dtype=itype,
requires_grad=False)
state_ref = state.clone()
out = None
out_ref = None
outs = []
......@@ -294,40 +298,40 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
if has_z:
assert z is not None
_z = z[..., chunk_start:chunk_end]
out, *rest = selective_scan_fn(u[..., chunk_start:chunk_end],
delta[..., chunk_start:chunk_end],
A,
_B,
_C,
D,
z=_z,
delta_bias=delta_bias,
delta_softplus=delta_softplus,
return_last_state=return_last_state,
prev_state=state if c > 0 else None)
out = selective_scan_fn(
u[..., chunk_start:chunk_end],
state,
delta[..., chunk_start:chunk_end],
A,
_B,
_C,
D,
z=_z,
delta_bias=delta_bias,
delta_softplus=delta_softplus,
has_initial_state=torch.ones(batch_size,
device=u.device,
dtype=torch.bool) if c > 0 else None)
outs.append(out)
if return_last_state:
state = rest[0]
if len(outs) > 1:
out = torch.cat(outs, dim=-1)
out_ref, *rest = selective_scan_ref(u,
delta,
A,
B,
C,
D,
z=z,
delta_bias=delta_bias,
delta_softplus=delta_softplus,
return_last_state=return_last_state)
if return_last_state:
state_ref = rest[0]
out_ref, state_ref, *rest = selective_scan_ref(
u_ref,
delta_ref,
A_ref,
B_ref,
C_ref,
D_ref,
z=z_ref,
delta_bias=delta_bias,
delta_softplus=delta_softplus,
return_last_state=True)
assert out is not None and out_ref is not None
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
if return_last_state:
assert state is not None and state_ref is not None
assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
assert state is not None and state_ref is not None
assert torch.allclose(state, state_ref.to(itype), rtol=rtol, atol=atol)
selective_scan_opcheck_fn(u,
delta,
......@@ -335,10 +339,10 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
B,
C,
D,
z=z,
z,
delta_bias=delta_bias,
delta_softplus=delta_softplus,
return_last_state=return_last_state)
ssm_states=state)
@pytest.mark.parametrize("itype",
......@@ -391,12 +395,163 @@ def test_selective_state_update(dim, dstate, has_z, itype):
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize('wtype', [torch.float32])
@pytest.mark.parametrize('itype', [torch.float32])
@pytest.mark.parametrize('seqlen', [1, 128, 129, 256, 512, 1024, 2048, 4096])
@pytest.mark.parametrize("return_last_state", [True])
@pytest.mark.parametrize('has_delta_bias', [True])
@pytest.mark.parametrize('delta_softplus', [True])
@pytest.mark.parametrize('has_z', [True])
@pytest.mark.parametrize('has_D', [True])
@pytest.mark.parametrize("varBC_groups", [1, 2])
@pytest.mark.parametrize("is_variable_C", [True])
@pytest.mark.parametrize("is_variable_B", [True])
# Also exercise the case where a trailing subset of the sequences is padding.
@pytest.mark.parametrize("with_padding", [False, True])
def test_selective_scan_varlen(with_padding, is_variable_B, is_variable_C,
                               varBC_groups, has_D, has_z, has_delta_bias,
                               delta_softplus, return_last_state, seqlen,
                               itype, wtype):
    """Compare ``selective_scan_fn`` on a packed batch with a per-sequence reference.

    A single token axis of length ``seqlen`` is cut into several sequences at
    random positions. ``selective_scan_fn`` is called once on the packed
    tensors together with the cumulative sequence lengths, the per-sequence
    state-cache indices and a random ``has_initial_state`` mask. The
    reference, ``selective_scan_ref``, is then run sequence by sequence; it is
    asked to write each final state into the matching row of
    ``prev_state_ref``. Finally both the outputs and the state caches are
    compared with ``torch.allclose``, and the op is run through
    ``selective_scan_opcheck_fn``.

    With ``with_padding`` the last ``padding`` sequences get ``PAD_SLOT_ID``
    as their cache index; they are skipped by the reference loop and only the
    leading (non-padded) part of the kernel output is compared.
    """
    if varBC_groups > 1 and (not is_variable_B or not is_variable_C):
        pytest.skip()  # This config is not applicable
    device = 'cuda'
    rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 3e-2, 5e-2
    # NOTE(review): rtolw / atolw are computed here but never read again in
    # this test.
    rtolw, atolw = (1e-3, 1e-3)
    if has_z:  # If we have z, the errors on the weights seem higher
        rtolw = max(rtolw, rtol)
        atolw = max(atolw, atol)
    # set seed
    torch.random.manual_seed(0)
    # ``seqlens`` ends up holding exactly one list: the per-sequence lengths.
    seqlens = []
    batch_size = 4
    if seqlen < 10:
        batch_size = 1
    padding = 3 if with_padding else 0
    padded_batch_size = batch_size + padding

    # Every sequence (padded ones included) needs at least one token.
    if with_padding and seqlen < padded_batch_size:
        pytest.skip()

    # Pick ``nsplits`` distinct end-of-sequence positions, then turn them into
    # lengths by differencing against the sentinels -1 and seqlen - 1.
    nsplits = padded_batch_size - 1
    eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
    seqlens.append(
        torch.diff(
            torch.cat(
                [torch.tensor([-1]), eos_pos,
                 torch.tensor([seqlen - 1])])).tolist())

    assert sum(seqlens[-1]) == seqlen
    assert all(s > 0 for s in seqlens[-1])

    # The state cache has more rows than sequences so that indices are sparse.
    total_entries = batch_size * 10
    # Cumulative sequence lengths with a leading 0, as int32 on the GPU.
    cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
    cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum],
                          dim=0).cuda()

    dim = 4
    dstate = 8
    A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype))
    A_ref = A.clone()
    B_shape = [varBC_groups, dstate, seqlen]
    B = torch.randn(B_shape,
                    device=device,
                    dtype=wtype if not is_variable_B else itype)
    B_ref = B.clone()
    C_shape = [varBC_groups, dstate, seqlen]
    C = torch.randn(C_shape,
                    device=device,
                    dtype=wtype if not is_variable_C else itype)
    C_ref = C.clone()
    D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None
    # NOTE(review): this would fail for has_D=False (D is None); the
    # parametrization above only uses has_D=True.
    D_ref = D.clone()
    # z is created unconditionally here, independent of ``has_z``.
    z = torch.randn(dim, seqlen, device=device, dtype=itype)
    z_ref = z.clone()
    delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)
                  ) if has_delta_bias else None
    # Packed inputs: no batch axis, all sequences share the last dimension.
    u = torch.randn(dim, seqlen, device=device, dtype=itype)
    u_ref = u.clone()
    delta = (0.5 * torch.rand(dim, seqlen, device=device, dtype=itype))
    delta_ref = delta.clone()
    out = None
    out_ref = None

    prev_state_shape = (total_entries, u.shape[0], int(A.shape[1]))
    prev_state = torch.randn(prev_state_shape,
                             device=u.device,
                             dtype=itype,
                             requires_grad=False)
    prev_state_ref = prev_state.clone()
    # Random, distinct cache rows for the real (non-padded) sequences.
    state_indices = torch.randperm(total_entries,
                                   dtype=torch.int32,
                                   device=u.device)[:batch_size]
    # NOTE(review): unused_states_bool is built but never read in this test.
    unused_states_bool = torch.ones(total_entries,
                                    dtype=torch.bool,
                                    device=device)
    unused_states_bool[state_indices] = False
    # Padded sequences are marked with PAD_SLOT_ID instead of a cache row.
    padded_state_indices = torch.concat([
        state_indices,
        torch.as_tensor(
            [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
    ],
                                        dim=-1)

    # One random flag per sequence (cumsum has one more entry than sequences).
    has_initial_state = torch.randint(0,
                                      2, (cumsum.shape[0] - 1, ),
                                      dtype=torch.bool,
                                      device=u.device)
    # Positional arguments follow selective_scan_fn's signature, which is not
    # visible in this chunk -- TODO confirm the order against its definition.
    out = selective_scan_fn(u, prev_state, delta, A, B, C, D, z, delta_bias,
                            delta_softplus, cumsum, padded_state_indices,
                            has_initial_state)
    outs_ref = []
    # Split every packed reference tensor along the token axis, per sequence.
    splits = [
        torch.split(var, seqlens[0], dim=-1)
        for var in (u_ref, delta_ref, B_ref, C_ref, z_ref)
    ]
    for i in range(len(seqlens[0])):
        # Re-add a leading batch axis of size 1 for the reference.
        u_s, delta_s, B_s, C_s, z_s = [v[i].unsqueeze(0) for v in splits]
        if padded_state_indices[i] == PAD_SLOT_ID:
            continue
        out_ref_s, _ = selective_scan_ref(
            u_s,
            delta_s,
            A_ref,
            B_s,
            C_s,
            D_ref,
            z=z_s,
            delta_bias=delta_bias,
            delta_softplus=delta_softplus,
            return_last_state=return_last_state,
            prev_state=prev_state_ref[padded_state_indices[i]].unsqueeze(0)
            if has_initial_state[i] else None,
            final_state_out=prev_state_ref[padded_state_indices[i]].unsqueeze(
                0))
        outs_ref.append(out_ref_s)
    # Concatenate the per-sequence outputs and drop the size-1 batch axis.
    out_ref = torch.cat(outs_ref, dim=-1)[0]

    # Only the leading tokens (those covered by the reference) are compared.
    unpadded_out = out[:, :out_ref[0].shape[-1]]
    print("Output diff max", (unpadded_out - out_ref).max())
    print("Output diff mean", (unpadded_out - out_ref).mean())
    print("Output state diff max", (prev_state - prev_state_ref).max())
    print("Output state diff mean", (prev_state - prev_state_ref).mean())
    assert torch.allclose(prev_state, prev_state_ref, rtol=rtol, atol=atol)
    assert torch.allclose(unpadded_out, out_ref, rtol=rtol, atol=atol)

    selective_scan_opcheck_fn(u, delta, A, B, C, D, z, delta_bias,
                              delta_softplus, cumsum, padded_state_indices,
                              has_initial_state, prev_state)
@pytest.mark.parametrize("itype",
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("has_z", [True])
@pytest.mark.parametrize("dstate", [16, 32, 64])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [True, False])
def test_selective_state_update_with_batch_indices(with_padding, dim, dstate,
has_z, itype):
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
if itype == torch.bfloat16:
......@@ -405,22 +560,33 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
atol *= 2
# set seed
torch.random.manual_seed(0)
batch_size = 16
batch_size = 3
padding = 5 if with_padding else 0
padded_batch_size = batch_size + padding
total_entries = 10 * batch_size
state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device)
state_indices = torch.randperm(total_entries)[:batch_size].to(
dtype=torch.int32, device=device)
x = torch.randn(batch_size, dim, device=device, dtype=itype)
dt = torch.randn(batch_size, dim, device=device, dtype=itype)
unused_states_bool = torch.ones(total_entries,
dtype=torch.bool,
device=device)
unused_states_bool[state_indices] = False
padded_state_indices = torch.concat([
state_indices,
torch.as_tensor(
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device)
],
dim=0)
x = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
dt = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
dt_bias = torch.rand(dim, device=device) - 4.0
A = -torch.rand(dim, dstate, device=device) - 1.0
B = torch.randn(batch_size, dstate, device=device)
C = torch.randn(batch_size, dstate, device=device)
B = torch.randn(padded_batch_size, dstate, device=device)
C = torch.randn(padded_batch_size, dstate, device=device)
D = torch.randn(dim, device=device)
z = torch.randn_like(x) if has_z else None
state_ref = state[state_indices, :].detach().clone()
state_ref = state[state_indices, :].clone()
state_before = state.clone()
out = selective_state_update(state,
x,
dt,
......@@ -431,23 +597,39 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
z=z,
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=state_indices)
state_batch_indices=padded_state_indices,
pad_slot_id=PAD_SLOT_ID)
out_ref = selective_state_update_ref(state_ref,
x,
dt,
x[:batch_size],
dt[:batch_size],
A,
B,
C,
B[:batch_size],
C[:batch_size],
D=D,
z=z,
z=z[:batch_size],
dt_bias=dt_bias,
dt_softplus=True)
print("Output diff max", (out - out_ref[0]).max())
print("Output diff mean", (out - out_ref[0]).mean())
print("Output state diff max", (state[state_indices, :] - state_ref).max())
print("Output state diff mean",
(state[state_indices, :] - state_ref).mean())
# test padded entries stay the same
if with_padding:
assert torch.equal(state_before[unused_states_bool],
state[unused_states_bool])
assert torch.equal(x[batch_size + 1:], x[batch_size + 1:])
assert torch.equal(dt[batch_size + 1:], dt[batch_size + 1:])
assert torch.equal(B[batch_size + 1:], B[batch_size + 1:])
assert torch.equal(C[batch_size + 1:], C[batch_size + 1:])
# test "real" entries
assert torch.allclose(state[state_indices, :],
state_ref,
rtol=rtol,
atol=atol)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("itype",
......@@ -465,7 +647,7 @@ def test_selective_state_update_with_heads_with_batch_indices(
rtol, atol = 1e-1, 1e-1
# set seed
torch.random.manual_seed(0)
batch_size = 16
batch_size = 3
headdim = 64
nheads = dim // headdim
......@@ -516,7 +698,8 @@ def test_selective_state_update_with_heads_with_batch_indices(
z=z,
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=state_indices)
state_batch_indices=state_indices,
pad_slot_id=PAD_SLOT_ID)
out_ref = selective_state_update_ref(state_ref,
x,
dt,
......
......@@ -2,16 +2,14 @@
Run `pytest tests/kernels/test_moe.py`.
"""
from typing import List
import pytest
import torch
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
from tests.kernels.utils import opcheck
from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev,
torch_moe, torch_moe_single)
from vllm import _custom_ops as ops
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe, single_marlin_moe)
......@@ -24,37 +22,6 @@ from vllm.scalar_type import scalar_types
from vllm.utils import seed_everything
def torch_moe(a, w1, w2, score, topk):
    """Reference mixture-of-experts forward pass written with plain torch ops.

    Every row of ``a`` is routed to its ``topk`` highest-scoring experts
    (softmax over ``score``). For each expert the selected rows go through
    ``w1`` -> ``SiluAndMul`` -> ``w2``, and the per-expert results are
    combined with the softmax weights.
    """
    num_tokens, hidden = a.shape
    # Duplicate each token ``topk`` times: one copy per selected expert.
    expanded = a.view(num_tokens, -1, hidden).repeat(1, topk,
                                                     1).reshape(-1, hidden)
    result = torch.zeros(num_tokens * topk,
                         w2.shape[1],
                         dtype=expanded.dtype,
                         device=expanded.device)
    probs = torch.softmax(score, dim=-1, dtype=torch.float32)
    weights, expert_ids = torch.topk(probs, topk)
    flat_weights = weights.view(-1)
    flat_ids = expert_ids.view(-1)
    for expert in range(w1.shape[0]):
        selected = flat_ids == expert
        if not selected.any():
            continue
        activation = SiluAndMul()
        hidden_states = activation(expanded[selected] @ w1[expert].transpose(
            0, 1))
        result[selected] = hidden_states @ w2[expert].transpose(0, 1)
    per_choice = result.view(num_tokens, -1, w2.shape[1])
    scaling = flat_weights.view(num_tokens, -1, 1).to(result.dtype)
    return (per_choice * scaling).sum(dim=1)
def torch_moe_single(a, w, score, topk):
    """Reference single-matmul MoE: route rows to experts and sum the results.

    Every row of ``a`` is multiplied by the transposed weight of each of its
    ``topk`` highest-scoring experts (softmax over ``score``); the ``topk``
    products are summed without applying the routing weights.
    """
    num_tokens, hidden = a.shape
    # Duplicate each token ``topk`` times: one copy per selected expert.
    expanded = a.view(num_tokens, -1, hidden).repeat(1, topk,
                                                     1).reshape(-1, hidden)
    out_features = w.shape[1]
    result = torch.zeros(num_tokens * topk,
                         out_features,
                         dtype=expanded.dtype,
                         device=expanded.device)
    probs = torch.softmax(score, dim=-1, dtype=torch.float32)
    flat_ids = torch.topk(probs, topk).indices.view(-1)
    for expert in range(w.shape[0]):
        selected = flat_ids == expert
        if not selected.any():
            continue
        result[selected] = expanded[selected] @ w[expert].transpose(0, 1)
    return result.view(num_tokens, -1, out_features).sum(dim=1)
@pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 511, 1024])
......@@ -127,24 +94,15 @@ def test_mixtral_moe(dtype: torch.dtype):
atol=mixtral_moe_tol[dtype])
def stack_and_dev(tensors: List[torch.Tensor]):
    """Stack ``tensors`` along a new leading axis on the first tensor's device."""
    target_device = tensors[0].device
    stacked = torch.stack(tensors, dim=0)
    return stacked.to(target_device)
def compute_max_diff(output, output_ref):
    """Return the mean absolute error divided by the mean absolute reference.

    Despite the name, this is a relative *mean* difference, not a maximum.
    """
    mean_abs_error = (output - output_ref).abs().mean()
    mean_abs_ref = output_ref.abs().mean()
    return mean_abs_error / mean_abs_ref
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [4, 8, 64])
@pytest.mark.parametrize("e", [8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("is_k_full", [True, False])
def test_fused_marlin_moe(
m: int,
n: int,
......@@ -154,18 +112,19 @@ def test_fused_marlin_moe(
group_size: int,
act_order: bool,
num_bits: int,
is_k_full: bool,
):
seed_everything(7)
if topk > e:
return
# Filter act_order
if act_order:
if group_size == -1:
return
if group_size in (k, n):
return
else:
if not is_k_full:
return
quant_type = (scalar_types.uint4b8
if num_bits == 4 else scalar_types.uint8b128)
......@@ -236,16 +195,17 @@ def test_fused_marlin_moe(
a,
qweight1,
qweight2,
scales1,
scales2,
score,
g_idx1,
g_idx2,
sort_indices1,
sort_indices2,
topk_weights,
topk_ids,
w1_scale=scales1,
w2_scale=scales2,
g_idx1=g_idx1,
g_idx2=g_idx2,
sort_indices1=sort_indices1,
sort_indices2=sort_indices2,
num_bits=num_bits,
is_k_full=is_k_full,
)
assert compute_max_diff(marlin_output, triton_output) < 4e-2
......@@ -274,9 +234,13 @@ def test_fused_marlin_moe(
device="cuda",
requires_grad=False)
zp = torch.empty((0, 0),
dtype=dtype,
device="cuda",
requires_grad=False)
opcheck(torch.ops._moe_C.marlin_gemm_moe,
(a, qweight1, sorted_token_ids, topk_weights, topk_ids,
scales1, g_idx1, sort_indices1, workspace, quant_type, m,
scales1, zp, g_idx1, sort_indices1, workspace, quant_type, m,
2 * n, k, True, e, topk, block_size_m, True, False))
......@@ -285,11 +249,12 @@ def test_fused_marlin_moe(
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [4, 8, 64])
@pytest.mark.parametrize("e", [8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("is_k_full", [True, False])
def test_single_marlin_moe_multiply(
m: int,
n: int,
......@@ -299,9 +264,8 @@ def test_single_marlin_moe_multiply(
group_size: int,
act_order: bool,
num_bits: int,
is_k_full: bool,
):
if topk > e:
return
# Filter act_order
if act_order:
......@@ -309,6 +273,9 @@ def test_single_marlin_moe_multiply(
return
if group_size == k:
return
else:
if not is_k_full:
return
quant_type = (scalar_types.uint4b8
if num_bits == 4 else scalar_types.uint8b128)
......@@ -339,15 +306,19 @@ def test_single_marlin_moe_multiply(
sort_indices = stack_and_dev(sort_indices_l)
score = torch.randn((m, e), device="cuda", dtype=dtype)
marlin_output = single_marlin_moe(a,
qweight,
scales,
score,
g_idx,
sort_indices,
topk,
renormalize=False,
num_bits=num_bits)
marlin_output = single_marlin_moe(
a,
qweight,
scales,
score,
topk,
renormalize=False,
g_idx=g_idx,
sort_indices=sort_indices,
num_bits=num_bits,
is_k_full=is_k_full,
)
torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)
assert compute_max_diff(marlin_output, torch_output) < 1e-2
......
......@@ -105,7 +105,7 @@ def test_batched_rotary_embedding(
if rotary_dim is None:
rotary_dim = head_size
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
"type": "linear",
"rope_type": "linear",
"factor": (1, )
})
rope = rope.to(dtype=dtype)
......@@ -166,7 +166,7 @@ def test_batched_rotary_embedding_multi_lora(
rotary_dim = head_size
scaling_factors: List[int] = [1, 2, 4]
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
"type": "linear",
"rope_type": "linear",
"factor": tuple(scaling_factors)
})
rope = rope.to(dtype=dtype)
......@@ -211,10 +211,10 @@ def test_rope_module_cache():
MAX_POSITIONS = [123, 1234]
BASES = [10000, 1000000]
ROPE_SCALINGS = (None, {
"type": "linear",
"rope_type": "linear",
"factor": (1, )
}, {
"type": "dynamic",
"rope_type": "dynamic",
"factor": 1
})
settings = (HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE,
......
......@@ -12,6 +12,7 @@ import torch
from torch._prims_common import TensorLikeType
from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL,
make_tensor_with_pad)
......@@ -974,6 +975,50 @@ def fp8_allclose(
equal_nan=equal_nan)).item())
# Marlin MoE test utils
def stack_and_dev(tensors: List[torch.Tensor]):
    """Stack ``tensors`` along a new leading axis on the first tensor's device."""
    target_device = tensors[0].device
    stacked = torch.stack(tensors, dim=0)
    return stacked.to(target_device)
def compute_max_diff(output, output_ref):
    """Return the mean absolute error divided by the mean absolute reference.

    Despite the name, this is a relative *mean* difference, not a maximum.
    """
    mean_abs_error = (output - output_ref).abs().mean()
    mean_abs_ref = output_ref.abs().mean()
    return mean_abs_error / mean_abs_ref
def torch_moe(a, w1, w2, score, topk):
    """Reference mixture-of-experts forward pass written with plain torch ops.

    Every row of ``a`` is routed to its ``topk`` highest-scoring experts
    (softmax over ``score``). For each expert the selected rows go through
    ``w1`` -> ``SiluAndMul`` -> ``w2``, and the per-expert results are
    combined with the softmax weights.
    """
    num_tokens, hidden = a.shape
    # Duplicate each token ``topk`` times: one copy per selected expert.
    expanded = a.view(num_tokens, -1, hidden).repeat(1, topk,
                                                     1).reshape(-1, hidden)
    result = torch.zeros(num_tokens * topk,
                         w2.shape[1],
                         dtype=expanded.dtype,
                         device=expanded.device)
    probs = torch.softmax(score, dim=-1, dtype=torch.float32)
    weights, expert_ids = torch.topk(probs, topk)
    flat_weights = weights.view(-1)
    flat_ids = expert_ids.view(-1)
    for expert in range(w1.shape[0]):
        selected = flat_ids == expert
        if not selected.any():
            continue
        activation = SiluAndMul()
        hidden_states = activation(expanded[selected] @ w1[expert].transpose(
            0, 1))
        result[selected] = hidden_states @ w2[expert].transpose(0, 1)
    per_choice = result.view(num_tokens, -1, w2.shape[1])
    scaling = flat_weights.view(num_tokens, -1, 1).to(result.dtype)
    return (per_choice * scaling).sum(dim=1)
def torch_moe_single(a, w, score, topk):
    """Reference single-matmul MoE: route rows to experts and sum the results.

    Every row of ``a`` is multiplied by the transposed weight of each of its
    ``topk`` highest-scoring experts (softmax over ``score``); the ``topk``
    products are summed without applying the routing weights.
    """
    num_tokens, hidden = a.shape
    # Duplicate each token ``topk`` times: one copy per selected expert.
    expanded = a.view(num_tokens, -1, hidden).repeat(1, topk,
                                                     1).reshape(-1, hidden)
    out_features = w.shape[1]
    result = torch.zeros(num_tokens * topk,
                         out_features,
                         dtype=expanded.dtype,
                         device=expanded.device)
    probs = torch.softmax(score, dim=-1, dtype=torch.float32)
    flat_ids = torch.topk(probs, topk).indices.view(-1)
    for expert in range(w.shape[0]):
        selected = flat_ids == expert
        if not selected.any():
            continue
        result[selected] = expanded[selected] @ w[expert].transpose(0, 1)
    return result.view(num_tokens, -1, out_features).sum(dim=1)
# A special version of op check that has a restricted default set of test_utils
# and a patched version of allclose that supports fp8 types.
def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
......
......@@ -173,6 +173,11 @@ def mixtral_lora_files():
return snapshot_download(repo_id="SangBinCho/mixtral-lora")
@pytest.fixture(scope="session")
def mixtral_lora_files_all_target_modules():
    """Session-scoped fixture for the ``dyang415/mixtral-lora-v0`` repo.

    Returns the value of ``snapshot_download`` for that repo id (presumably
    the local snapshot directory -- confirm against the import at the top of
    the file).
    """
    return snapshot_download(repo_id="dyang415/mixtral-lora-v0")
@pytest.fixture(scope="session")
def gemma_lora_files():
    """Session-scoped fixture for the ``wskwon/gemma-7b-test-lora`` repo.

    Returns the value of ``snapshot_download`` for that repo id (presumably
    the local snapshot directory -- confirm against the import at the top of
    the file).
    """
    return snapshot_download(repo_id="wskwon/gemma-7b-test-lora")
......@@ -194,6 +199,16 @@ def baichuan_zero_lora_files():
return snapshot_download(repo_id="jeeejeee/baichuan7b-zero-init")
@pytest.fixture(scope="session")
def baichuan_regex_lora_files():
    """Session-scoped fixture for the ``jeeejeee/baichuan-7b-lora-zero-regex`` repo.

    Returns the value of ``snapshot_download`` for that repo id (presumably
    the local snapshot directory -- confirm against the import at the top of
    the file).
    """
    return snapshot_download(repo_id="jeeejeee/baichuan-7b-lora-zero-regex")
@pytest.fixture(scope="session")
def minicpmv_lora_files():
    """Session-scoped fixture for the ``jeeejeee/minicpmv25-lora-pokemon`` repo.

    Returns the value of ``snapshot_download`` for that repo id (presumably
    the local snapshot directory -- confirm against the import at the top of
    the file).
    """
    return snapshot_download(repo_id="jeeejeee/minicpmv25-lora-pokemon")
@pytest.fixture(scope="session")
def tinyllama_lora_files():
    """Session-scoped fixture for the ``jashing/tinyllama-colorist-lora`` repo.

    Returns the value of ``snapshot_download`` for that repo id (presumably
    the local snapshot directory -- confirm against the import at the top of
    the file).
    """
    return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")
......
......@@ -63,12 +63,11 @@ def test_baichuan_lora(baichuan_lora_files):
assert output2[i] == expected_lora_output[i]
@pytest.mark.skip("Requires multiple GPUs")
@pytest.mark.parametrize("fully_sharded", [True, False])
def test_baichuan_tensor_parallel_equality(baichuan_lora_files, fully_sharded):
# Cannot use as it will initialize torch.cuda too early...
# if torch.cuda.device_count() < 4:
# pytest.skip(f"Not enough GPUs for tensor parallelism {4}")
def test_baichuan_tensor_parallel_equality(baichuan_lora_files,
num_gpus_available, fully_sharded):
if num_gpus_available < 4:
pytest.skip(f"Not enough GPUs for tensor parallelism {4}")
llm_tp1 = vllm.LLM(MODEL_PATH,
enable_lora=True,
......
......@@ -951,7 +951,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
lora_rope.create_lora_weights(max_loras, lora_config)
linear_rope = get_rope(head_size, rotary_dim, max_position, base,
is_neox_style, {
"type": "linear",
"rope_type": "linear",
"factor": scaling_factors
})
linear_rope = linear_rope.to(dtype=dtype)
......
......@@ -5,7 +5,9 @@ import pytest
from vllm.lora.models import LoRAModel
from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
lora_lst = ["baichuan7B", "baichuan7B-zero", "chatglm3-6b"]
lora_lst = [
"baichuan7B", "baichuan7B-zero", "baichuan7B-zero-regex", "chatglm3-6b"
]
@pytest.mark.parametrize("lora_name", lora_lst)
......@@ -13,6 +15,7 @@ def test_load_checkpoints(
lora_name,
baichuan_lora_files,
baichuan_zero_lora_files,
baichuan_regex_lora_files,
chatglm3_lora_files,
):
supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules
......@@ -36,7 +39,7 @@ def test_load_checkpoints(
embedding_modules=embedding_modules,
embedding_padding_modules=embed_padding_modules)
elif lora_name == "baichuan7B-zero":
#Test that the target_modules contain prefix
# Test that the target_modules contain prefix
# such as "model.layers.0.self_atten.W_pack", and
# the test should pass.
LoRAModel.from_local_checkpoint(
......@@ -46,6 +49,16 @@ def test_load_checkpoints(
device="cpu",
embedding_modules=embedding_modules,
embedding_padding_modules=embed_padding_modules)
elif lora_name == "baichuan7B-zero-regex":
# Test that `target_modules` given in the form of regular expressions,
# such as `model\\..*(W_pack|o_proj)`, are supported; the test should pass.
LoRAModel.from_local_checkpoint(
baichuan_regex_lora_files,
expected_lora_modules,
lora_model_id=1,
device="cpu",
embedding_modules=embedding_modules,
embedding_padding_modules=embed_padding_modules)
else:
# For the baichuan7B model, load chatglm3-6b's LoRA,
# and the test should raise the following error.
......
from typing import List
import vllm
from vllm.assets.image import ImageAsset
from vllm.lora.request import LoRARequest
MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5"
PROMPT_TEMPLATE = (
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
"(<image>./</image>)\nWhat is in the image?<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n")
IMAGE_ASSETS = [
ImageAsset("stop_sign"),
ImageAsset("cherry_blossom"),
]
# After fine-tuning with LoRA, all generated content should begin with `A`.
EXPECTED_OUTPUT = [
"A red and white stop sign with a Chinese archway in the background featuring red lanterns and gold accents.", # noqa: E501
"A pink cherry blossom tree with a blue sky in the background.",
]
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
    """Generate one short completion per image asset and return the texts.

    Sampling is greedy (``temperature=0``) and capped at 5 tokens. When
    ``lora_id`` is falsy no LoRA request is attached; otherwise the adapter
    at ``lora_path`` is requested under that id. Each prompt/completion pair
    is printed, and the stripped completions are returned in output order.
    """
    sampling_params = vllm.SamplingParams(
        temperature=0,
        max_tokens=5,
        stop_token_ids=[128001, 128009],  # eos_id, eot_id
    )

    # Same text prompt for every image; only the image payload changes.
    inputs = []
    for asset in IMAGE_ASSETS:
        inputs.append({
            "prompt": PROMPT_TEMPLATE,
            "multi_modal_data": {
                "image": asset.pil_image
            },
        })

    if lora_id:
        lora_request = LoRARequest(str(lora_id), lora_id, lora_path)
    else:
        lora_request = None

    outputs = llm.generate(inputs, sampling_params, lora_request=lora_request)

    # Print the outputs.
    generated_texts: List[str] = []
    for request_output in outputs:
        text = request_output.outputs[0].text.strip()
        generated_texts.append(text)
        print(f"Prompt: {request_output.prompt!r}, Generated text: {text!r}")
    return generated_texts
def test_minicpmv_lora(minicpmv_lora_files):
    """Check MiniCPM-V generations with the LoRA adapter against expected text.

    The same adapter files are sampled under two LoRA ids (1 and 2). Every
    generation must be a non-empty prefix of the corresponding
    ``EXPECTED_OUTPUT`` entry (``do_sample`` caps generation at a few tokens,
    hence the prefix comparison).
    """
    llm = vllm.LLM(
        MODEL_PATH,
        max_num_seqs=2,
        enable_lora=True,
        max_loras=4,
        max_lora_rank=64,
        trust_remote_code=True,
    )

    for lora_id in (1, 2):
        generated = do_sample(llm, minicpmv_lora_files, lora_id=lora_id)
        assert len(generated) == len(EXPECTED_OUTPUT)
        for expected, actual in zip(EXPECTED_OUTPUT, generated):
            # ``str.startswith("")`` is always True, so without this check an
            # empty generation would pass the prefix comparison vacuously.
            assert actual, f"empty generation for lora_id={lora_id}"
            assert expected.startswith(actual)
from typing import List
import pytest
import vllm
from vllm.assets.image import ImageAsset
from vllm.lora.request import LoRARequest
from ..utils import multi_gpu_test
MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5"
PROMPT_TEMPLATE = (
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
"(<image>./</image>)\nWhat is in the image?<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n")
IMAGE_ASSETS = [
ImageAsset("stop_sign"),
ImageAsset("cherry_blossom"),
]
# After fine-tuning with LoRA, all generated content should begin with `A`.
EXPECTED_OUTPUT = [
"A red and white stop sign with a Chinese archway in the background featuring red lanterns and gold accents.", # noqa: E501
"A pink cherry blossom tree with a blue sky in the background.",
]
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
    """Generate one short completion per image asset and return the texts.

    Sampling is greedy (``temperature=0``) and capped at 5 tokens. When
    ``lora_id`` is falsy no LoRA request is attached; otherwise the adapter
    at ``lora_path`` is requested under that id. Each prompt/completion pair
    is printed, and the stripped completions are returned in output order.
    """
    sampling_params = vllm.SamplingParams(
        temperature=0,
        max_tokens=5,
        stop_token_ids=[128001, 128009],  # eos_id, eot_id
    )

    # Same text prompt for every image; only the image payload changes.
    inputs = []
    for asset in IMAGE_ASSETS:
        inputs.append({
            "prompt": PROMPT_TEMPLATE,
            "multi_modal_data": {
                "image": asset.pil_image
            },
        })

    if lora_id:
        lora_request = LoRARequest(str(lora_id), lora_id, lora_path)
    else:
        lora_request = None

    outputs = llm.generate(inputs, sampling_params, lora_request=lora_request)

    # Print the outputs.
    generated_texts: List[str] = []
    for request_output in outputs:
        text = request_output.outputs[0].text.strip()
        generated_texts.append(text)
        print(f"Prompt: {request_output.prompt!r}, Generated text: {text!r}")
    return generated_texts
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("fully_sharded", [True, False])
def test_minicpmv_tp2(minicpmv_lora_files, fully_sharded):
    """Check MiniCPM-V LoRA generations with tensor parallel size 2.

    Runs with and without ``fully_sharded_loras``. Every generation must be
    a non-empty prefix of the corresponding ``EXPECTED_OUTPUT`` entry.
    """
    llm = vllm.LLM(
        MODEL_PATH,
        enable_lora=True,
        max_num_seqs=2,
        max_loras=4,
        max_lora_rank=64,
        tensor_parallel_size=2,
        trust_remote_code=True,
        fully_sharded_loras=fully_sharded,
    )

    output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1)

    assert len(output_tp) == len(EXPECTED_OUTPUT)
    for expected, actual in zip(EXPECTED_OUTPUT, output_tp):
        # ``str.startswith("")`` is always True, so without this check an
        # empty generation would pass the prefix comparison vacuously.
        assert actual, "empty generation with tensor_parallel_size=2"
        assert expected.startswith(actual)
@multi_gpu_test(num_gpus=4)
@pytest.mark.parametrize("fully_sharded", [True, False])
def test_minicpmv_tp4(minicpmv_lora_files, fully_sharded):
    """Check MiniCPM-V LoRA generations with tensor parallel size 4.

    Runs with and without ``fully_sharded_loras``. Every generation must be
    a non-empty prefix of the corresponding ``EXPECTED_OUTPUT`` entry.
    """
    llm = vllm.LLM(
        MODEL_PATH,
        enable_lora=True,
        max_num_seqs=2,
        max_loras=4,
        max_lora_rank=64,
        tensor_parallel_size=4,
        trust_remote_code=True,
        fully_sharded_loras=fully_sharded,
    )

    output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1)

    assert len(output_tp) == len(EXPECTED_OUTPUT)
    for expected, actual in zip(EXPECTED_OUTPUT, output_tp):
        # ``str.startswith("")`` is always True, so without this check an
        # empty generation would pass the prefix comparison vacuously.
        assert actual, "empty generation with tensor_parallel_size=4"
        assert expected.startswith(actual)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment