Unverified Commit 2c11a738 authored by Congcong Chen's avatar Congcong Chen Committed by GitHub
Browse files

[Model] New model support for microsoft/Phi-4-mini-flash-reasoning (#20702)


Signed-off-by: default avatarCongcong Chen <congcongchen@microsoft.com>
parent b639327a
......@@ -312,19 +312,20 @@ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
// kIsVariableB, kIsVariableC and kHasZ are all set to True to reduce binary size
constexpr bool kIsVariableB = true;
constexpr bool kIsVariableC = true;
constexpr bool kHasZ = true;
BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] {
using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kVarlen, input_t, weight_t>;
constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
dim3 grid(params.batch, params.dim / kNRows);
auto kernel = &selective_scan_fwd_kernel<Ktraits>;
if (kSmemSize >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
}
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] {
using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kVarlen, input_t, weight_t>;
constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
dim3 grid(params.batch, params.dim / kNRows);
auto kernel = &selective_scan_fwd_kernel<Ktraits>;
if (kSmemSize >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
}
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
}
......@@ -612,19 +613,20 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
at::Tensor z, out_z;
const bool has_z = z_.has_value();
TORCH_CHECK(has_z, "has_z = False is disabled in favor of reduced binary size")
z = z_.value();
TORCH_CHECK(z.scalar_type() == input_type);
TORCH_CHECK(z.is_cuda());
TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
if (varlen){
CHECK_SHAPE(z, dim, seqlen);
} else {
CHECK_SHAPE(z, batch_size, dim, seqlen);
if (has_z) {
z = z_.value();
TORCH_CHECK(z.scalar_type() == input_type);
TORCH_CHECK(z.is_cuda());
TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
if (varlen){
CHECK_SHAPE(z, dim, seqlen);
} else {
CHECK_SHAPE(z, batch_size, dim, seqlen);
}
out_z = z;
}
out_z = z;
// Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
at::Tensor out = delta;
TORCH_CHECK(ssm_states.scalar_type() == input_type);
......@@ -653,4 +655,3 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
});
}
......@@ -374,6 +374,7 @@ Specified using `--task generate`.
| `Phi3ForCausalLM` | Phi-4, Phi-3 | `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Phi3SmallForCausalLM` | Phi-3-Small | `microsoft/Phi-3-small-8k-instruct`, `microsoft/Phi-3-small-128k-instruct`, etc. | | ✅︎ | ✅︎ |
| `PhiMoEForCausalLM` | Phi-3.5-MoE | `microsoft/Phi-3.5-MoE-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Phi4FlashForCausalLM` | Phi-4-mini-flash-reasoning | `microsoft/microsoft/Phi-4-mini-instruct`, etc. | | | |
| `PersimmonForCausalLM` | Persimmon | `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. | | ✅︎ | ✅︎ |
| `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | | |
| `QWenLMHeadModel` | Qwen | `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
......
......@@ -248,6 +248,10 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"Phi3SmallForCausalLM": _HfExamplesInfo("microsoft/Phi-3-small-8k-instruct",
trust_remote_code=True,
v0_only=True),
"Phi4FlashForCausalLM": _HfExamplesInfo("microsoft/Phi-4-mini-flash-reasoning", # noqa: E501
trust_remote_code=True,
v0_only=True,
max_model_len=10240),
"PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct",
trust_remote_code=True),
"Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b",
......
......@@ -103,6 +103,9 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
_initialize_kv_caches_v1), monkeypatch.context() as m):
if model_info.v0_only:
m.setenv("VLLM_USE_V1", "0")
if model_arch == "Phi4FlashForCausalLM":
# Phi4FlashForCausalLM only supports DIFFERENTIAL_FLASH_ATTN backend
m.setenv("VLLM_ATTENTION_BACKEND", "DIFFERENTIAL_FLASH_ATTN")
LLM(
model_info.default,
tokenizer=model_info.tokenizer,
......
......@@ -458,6 +458,31 @@ def test_bind_kv_cache():
assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[2]
assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[3]
def test_bind_kv_cache_kv_sharing():
from vllm.attention import Attention
ctx = {
'layers.0.self_attn': Attention(32, 128, 0.1),
'layers.1.self_attn': Attention(32, 128, 0.1),
'layers.2.self_attn': Attention(32, 128, 0.1),
'layers.3.self_attn': Attention(32, 128, 0.1),
}
kv_cache = [
torch.zeros((1, )),
torch.zeros((1, )),
torch.zeros((1, )),
torch.zeros((1, )),
]
shared_kv_cache_layers = {
'layers.2.self_attn': 'layers.1.self_attn',
'layers.3.self_attn': 'layers.0.self_attn'
}
bind_kv_cache(ctx, [kv_cache], shared_kv_cache_layers)
assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0]
assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[1]
assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[1]
assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[0]
def test_bind_kv_cache_non_attention():
from vllm.attention import Attention
......
......@@ -308,7 +308,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
kv_sharing_target_layer_name: Optional[str] = None,
) -> None:
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0.")
raise NotImplementedError("KV sharing is not supported in V0 "
"BLOCK_SPARSE_FLASH_ATTN Backend.")
assert blocksparse_params is not None
assert alibi_slopes is None, ValueError(
"Alibi not support for blocksparse flash attention.")
......
This diff is collapsed.
......@@ -295,7 +295,8 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
dual_chunk_attention_config: Optional[Dict[str, Any]] = None,
) -> None:
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0.")
raise NotImplementedError("KV sharing is not supported in V0 "
"DUAL_CHUNK_FLASH_ATTN backend.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
......
......@@ -622,7 +622,8 @@ class FlashAttentionImpl(AttentionImpl):
use_irope: bool = False,
) -> None:
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0.")
raise NotImplementedError("KV sharing is not supported in V0 "
"FLASH_ATTN backend.")
if blocksparse_params is not None:
raise ValueError(
"FlashAttention does not support block-sparse attention.")
......
......@@ -1006,7 +1006,8 @@ class FlashInferImpl(AttentionImpl):
use_irope: bool = False,
) -> None:
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0.")
raise NotImplementedError("KV sharing is not supported in V0 "
"FLASHINFER backend.")
if use_irope:
logger.warning_once(
"Using irope in FlashInfer is not supported yet, it will fall"
......
......@@ -115,7 +115,8 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
) -> None:
super(AttentionImpl, self).__init__()
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0.")
raise NotImplementedError("KV sharing is not supported in V0 "
"HPU_ATTN backend.")
if use_irope:
logger.warning_once(
"Using irope in HPU is not supported yet, it will fall back "
......
......@@ -501,7 +501,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
use_irope: bool = False,
) -> None:
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0.")
raise NotImplementedError("KV sharing is not supported in V0 "
"ROCM_FLASH backend.")
if use_irope:
logger.warning_once(
"Using irope in ROCm Flash Attention is not supported yet, it "
......
......@@ -394,7 +394,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
use_irope: bool = False,
) -> None:
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0.")
raise NotImplementedError("KV sharing is not supported in V0 "
"XFORMERS backend.")
if blocksparse_params is not None:
raise ValueError(
"XFormers does not support block-sparse attention.")
......
......@@ -160,10 +160,6 @@ class Attention(nn.Module):
self.attn_type = attn_type
if kv_sharing_target_layer_name is not None:
if not envs.VLLM_USE_V1:
raise NotImplementedError(
"Cross-layer KV sharing is not supported in V0.")
validate_kv_sharing_target(
prefix,
kv_sharing_target_layer_name,
......
......@@ -59,11 +59,12 @@ class LogitsProcessor(nn.Module):
hidden_states: torch.Tensor,
sampling_metadata: Optional[SamplingMetadata] = None,
embedding_bias: Optional[torch.Tensor] = None,
prune_hidden_states: bool = True,
) -> Optional[torch.Tensor]:
if self.logits_as_input:
logits = hidden_states
else:
if sampling_metadata is not None:
if sampling_metadata is not None and prune_hidden_states:
hidden_states = _prune_hidden_states(hidden_states,
sampling_metadata)
......
This diff is collapsed.
......@@ -110,6 +110,7 @@ _TEXT_GENERATION_MODELS = {
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
"Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
"PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
"Phi4FlashForCausalLM": ("phi4flash", "Phi4FlashForCausalLM"),
"Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
......
......@@ -316,6 +316,10 @@ class CudaPlatformBase(Platform):
logger.info("Using DualChunkFlashAttention backend.")
return ("vllm.attention.backends.dual_chunk_flash_attn."
"DualChunkFlashAttentionBackend")
elif selected_backend == _Backend.DIFFERENTIAL_FLASH_ATTN:
logger.info("Using DifferentialFlashAttention backend.")
return ("vllm.attention.backends.differential_flash_attn."
"DifferentialFlashAttentionBackend")
elif selected_backend == _Backend.FLASH_ATTN:
pass
elif selected_backend:
......
......@@ -60,6 +60,7 @@ class _Backend(enum.Enum):
IPEX = enum.auto()
BLOCK_SPARSE_FLASH_ATTN = enum.auto()
DUAL_CHUNK_FLASH_ATTN = enum.auto()
DIFFERENTIAL_FLASH_ATTN = enum.auto()
NO_ATTENTION = enum.auto()
FLEX_ATTENTION = enum.auto()
......
......@@ -2888,8 +2888,9 @@ def get_mp_context():
def bind_kv_cache(
ctx: dict[str, Any],
kv_cache: list[list[torch.Tensor]], # [virtual_engine][layer_index]
ctx: dict[str, Any],
kv_cache: list[list[torch.Tensor]], # [virtual_engine][layer_index]
shared_kv_cache_layers: Optional[dict[str, str]] = None
) -> None:
# Bind the kv_cache tensor to Attention modules, similar to
# ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)]
......@@ -2901,12 +2902,17 @@ def bind_kv_cache(
# attention of the same layer (e.g., bart's decoder.layers.1.self_attn
# and decoder.layers.1.encoder_attn) is mapped to the same kv cache
# tensor
# 5. Some models have attention layers that share kv cache with previous
# layers, this is specified through shared_kv_cache_layers
if shared_kv_cache_layers is None:
shared_kv_cache_layers = {}
from vllm.attention import AttentionType
from vllm.model_executor.models.utils import extract_layer_index
layer_need_kv_cache = [
layer_name for layer_name in ctx
if (hasattr(ctx[layer_name], 'attn_type') and ctx[layer_name].attn_type
in (AttentionType.DECODER, AttentionType.ENCODER_DECODER))
in (AttentionType.DECODER, AttentionType.ENCODER_DECODER)) \
and ctx[layer_name].kv_sharing_target_layer_name is None
]
layer_index_sorted = sorted(
set(
......@@ -2919,6 +2925,12 @@ def bind_kv_cache(
assert len(forward_ctx.kv_cache) == len(kv_cache)
for ve, ve_kv_cache in enumerate(kv_cache):
forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx]
if shared_kv_cache_layers is not None:
for layer_name, target_layer_name in shared_kv_cache_layers.items():
assert extract_layer_index(target_layer_name) < \
extract_layer_index(layer_name), \
"v0 doesn't support interleaving kv sharing"
ctx[layer_name].kv_cache = ctx[target_layer_name].kv_cache
def run_method(obj: Any, method: Union[str, bytes, Callable], args: tuple[Any],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment