Commit 858bddce authored by luopl's avatar luopl
Browse files

feat:add gemma4

parent 40faaf0c
......@@ -13,6 +13,7 @@ from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding
from .dynamic_ntk_scaling_rope import DynamicNTKScalingRotaryEmbedding
from .fope import FourierRotaryEmbedding
from .linear_scaling_rope import LinearScalingRotaryEmbedding
from .gemma4_rope import Gemma4RotaryEmbedding
from .llama3_rope import Llama3RotaryEmbedding
from .llama4_vision_rope import Llama4VisionRotaryEmbedding
from .mrope import MRotaryEmbedding
......@@ -134,6 +135,17 @@ def get_rope(
is_neox_style,
dtype,
)
elif scaling_type == "proportional":
# Proportional RoPE is used by Gemma4 for global (full) attention.
# Gemma4 uses a sparse/fractional RoPE with cross-mixing between halves.
rotary_emb = Gemma4RotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
)
elif scaling_type == "llama3":
scaling_factor = rope_parameters["factor"]
low_freq_factor = rope_parameters["low_freq_factor"]
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Gemma4-specific Rotary Positional Embeddings (proportional scaling).
Gemma4 uses "proportional" RoPE which computes inv_freq frequencies scaled
by head_dim (not rotary_dim), and zero-pads for non-rotated dimensions when
partial_rotary_factor < 1. The actual rotation uses standard neox-style
rotate_half, matching HF transformers' apply_rotary_pos_emb.
"""
import torch
from .base import RotaryEmbedding
class Gemma4RotaryEmbedding(RotaryEmbedding):
"""Gemma4 proportional RoPE.
Extends RotaryEmbedding (which provides standard neox-style rotation
via ops.rotary_embedding CUDA kernel) but overrides the inv_freq
computation to match HF's _compute_proportional_rope_parameters:
- Frequency exponents use head_dim (not rotary_dim) as denominator
- Non-rotated dims are zero-padded (cos=1, sin=0 = identity rotation)
When partial_rotary_factor=1.0 (the default for some variants), ALL dims are
rotated and this is equivalent to standard RotaryEmbedding with
head_dim-scaled frequencies.
"""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
dtype: torch.dtype,
) -> None:
# Number of rotation angle pairs (from partial_rotary_factor)
self.rope_angles = rotary_dim // 2
# Non-rotated angle pairs per half
self.nope_angles = (head_size // 2) - self.rope_angles
# Important: set rotary_dim = head_size so the base class's
# forward_static applies rotation to ALL dims of the cos/sin cache.
# The non-rotated dims will have cos=1, sin=0 (identity) thanks
# to our _compute_inv_freq zero-padding.
super().__init__(
head_size,
head_size, # rotary_dim = head_size (full application)
max_position_embeddings,
base,
is_neox_style,
dtype,
)
def _compute_inv_freq(self, base: float) -> torch.Tensor:
"""Compute frequencies matching HF proportional RoPE.
Key difference from base: exponent denominator is head_size (not
rotary_dim), and non-rotated dims are zero-padded.
"""
# HF formula: base ** (arange(0, 2*rope_angles, 2) / head_dim)
freq_exponents = (
torch.arange(0, 2 * self.rope_angles, 2, dtype=torch.float) / self.head_size
)
inv_freq = 1.0 / (base**freq_exponents)
# Zero-pad for non-rotated dims (identity rotation: cos=1, sin=0)
if self.nope_angles > 0:
inv_freq = torch.cat(
[
inv_freq,
torch.zeros(self.nope_angles, dtype=torch.float),
]
)
return inv_freq
def extra_repr(self) -> str:
s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
s += f", rope_angles={self.rope_angles}, nope_angles={self.nope_angles}"
s += f", max_position_embeddings={self.max_position_embeddings}"
s += f", base={self.base}, is_neox_style={self.is_neox_style}"
return s
......@@ -56,6 +56,57 @@ class Gemma3TextModelConfig(VerifyAndUpdateConfig):
hf_config = model_config.hf_config
hf_config.is_causal = not hf_config.use_bidirectional_attention
class Gemma4Config(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
"""Force unified attention backend for models with heterogeneous
head dimensions.
Some Gemma4 variants use different head dimensions for
sliding window (head_dim) vs full attention (global_head_dim) layers.
When global_head_dim > 256, FlashAttention rejects those layers
(head_size <= 256 kernel limit), causing vLLM to select a different
backend for each layer type. This mixed-backend execution produces
numerical divergence and output corruption.
The fix detects heterogeneous head dimensions from the model config
and forces TRITON_ATTN (which has no head_size ceiling) for all
layers when the user hasn't explicitly chosen a backend.
TODO: Heterogeneous head_sizes (head_dim != global_head_dim)
require NixlConnector changes to support per-layer KV transfer
with different head dimensions for prefill-decode disaggregation.
"""
hf_text_config = vllm_config.model_config.hf_text_config
head_dim = getattr(hf_text_config, "head_dim", None)
global_head_dim = getattr(hf_text_config, "global_head_dim", None)
# Only force Triton when head dimensions actually differ AND the
# larger one exceeds FlashAttention's kernel limit (head_size <= 256).
# This avoids unnecessary backend forcing on smaller models where
# the config carries global_head_dim but all layers can still use
# the same FA backend.
max_head_dim = max(head_dim or 0, global_head_dim or 0)
if (
head_dim is not None
and global_head_dim is not None
and head_dim != global_head_dim
and max_head_dim > 256
and vllm_config.attention_config.backend is None
):
from vllm.v1.attention.backends.registry import (
AttentionBackendEnum,
)
vllm_config.attention_config.backend = AttentionBackendEnum.TRITON_ATTN
logger.info(
"Gemma4 model has heterogeneous head dimensions "
"(head_dim=%d, global_head_dim=%d). Forcing TRITON_ATTN "
"backend to prevent mixed-backend numerical divergence.",
head_dim,
global_head_dim,
)
class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
@staticmethod
......@@ -647,10 +698,13 @@ class VoyageQwen3BidirectionalEmbedModelConfig(VerifyAndUpdateConfig):
MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"ColBERTJinaRobertaModel": JinaRobertaModelConfig,
"ColQwen3_5": Qwen3_5ForConditionalGenerationConfig,
"DeepseekV32ForCausalLM": DeepseekV32ForCausalLM,
"Ernie4_5_VLMoeForConditionalGeneration": Ernie4_5_VLMoeForConditionalGenerationConfig, # noqa: E501
"FalconMambaForCausalLM": MambaModelConfig,
"Gemma3TextModel": Gemma3TextModelConfig,
"Gemma4ForCausalLM": Gemma4Config,
"Gemma4ForConditionalGeneration": Gemma4Config,
"GptOssForCausalLM": GptOssForCausalLMConfig,
"GteModel": SnowflakeGteNewModelConfig,
"GteNewForSequenceClassification": GteNewModelConfig,
......
This diff is collapsed.
This diff is collapsed.
......@@ -111,6 +111,7 @@ _TEXT_GENERATION_MODELS = {
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
"Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
"Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
"Gemma4ForCausalLM": ("gemma4", "Gemma4ForCausalLM"),
"Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
"Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
......@@ -377,6 +378,7 @@ _MULTIMODAL_MODELS = {
"gemma3n_mm",
"Gemma3nForConditionalGeneration",
),
"Gemma4ForConditionalGeneration": ("gemma4_mm", "Gemma4ForConditionalGeneration"),
"GlmAsrForConditionalGeneration": ("glmasr", "GlmAsrForConditionalGeneration"),
"GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
"Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),
......
......@@ -233,8 +233,15 @@ class AutoWeightsLoader:
):
"""
Add tensor names that are not in the model params that may be in the
safetensors, e.g., batch normalization stats.
safetensors, e.g., batch normalization stats and registered buffers.
"""
# Add persistent registered buffers.
# Non-persistent buffers are excluded, matching PyTorch state_dict().
non_persistent = getattr(module, "_non_persistent_buffers_set", set())
for buf_name, buf in module.named_buffers(recurse=False):
if buf_name not in child_params and buf_name not in non_persistent:
child_params[buf_name] = buf
if isinstance(
module,
(
......
......@@ -32,6 +32,10 @@ _REASONING_PARSERS_TO_REGISTER = {
"ernie45_reasoning_parser",
"Ernie45ReasoningParser",
),
"gemma4": (
"gemma4_reasoning_parser",
"Gemma4ReasoningParser",
),
"glm45": (
"deepseek_v3_reasoning_parser",
"DeepSeekV3ReasoningWithThinkingParser",
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import TYPE_CHECKING
from vllm.entrypoints.openai.engine.protocol import DeltaMessage
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
from vllm.tokenizers import TokenizerLike
if TYPE_CHECKING:
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
# Role label that Gemma4 emits at the start of the thinking channel.
# The model generates: <|channel>thought\n...reasoning...<channel|>
# This prefix must be stripped to expose only the actual reasoning content.
_THOUGHT_PREFIX = "thought\n"
class Gemma4ReasoningParser(BaseThinkingReasoningParser):
"""
Reasoning parser for Google Gemma4 thinking models.
Gemma4 uses <|channel>...<channel|> tokens to delimit reasoning/thinking
content within its output. Thinking mode is activated by passing
``enable_thinking=True`` in the chat template kwargs, which injects a
system turn containing <|think|> (token 98) to trigger chain-of-thought
reasoning.
Output pattern when thinking is enabled::
<|channel>thought
...chain of thought reasoning...<channel|>
Final answer text here.
The ``thought\\n`` role label inside the channel delimiters is a
structural artefact (analogous to ``user\\n`` in ``<|turn>user\\n...``).
This parser strips it so that downstream consumers see only the
actual reasoning text, consistent with the offline parser
(``vllm.reasoning.gemma4_utils._strip_thought_label``).
"""
def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
super().__init__(tokenizer, *args, **kwargs)
# Instance state for streaming prefix stripping.
# Tracks only the reasoning text received from the base parser,
# independent of current_text (which may contain pre-reasoning
# content and lacks special token text due to
# skip_special_tokens=True).
self._reasoning_text: str = ""
self._prefix_stripped: bool = False
self.new_turn_token_id = self.vocab["<|turn>"]
self.tool_call_token_id = self.vocab["<|tool_call>"]
self.tool_response_token_id = self.vocab["<|tool_response>"]
def adjust_request(
self, request: "ChatCompletionRequest | ResponsesRequest"
) -> "ChatCompletionRequest | ResponsesRequest":
"""Disable special-token stripping to preserve boundary tokens."""
request.skip_special_tokens = False
return request
@property
def start_token(self) -> str:
"""The token that starts reasoning content."""
return "<|channel>"
@property
def end_token(self) -> str:
"""The token that ends reasoning content."""
return "<channel|>"
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
start_token_id = self.start_token_id
end_token_id = self.end_token_id
new_turn_token_id = self.new_turn_token_id
tool_call_token_id = self.tool_call_token_id
tool_response_token_id = self.tool_response_token_id
# Search from the end of input_ids to find the last match.
for i in range(len(input_ids) - 1, -1, -1):
if input_ids[i] == start_token_id:
return False
if input_ids[i] == tool_call_token_id:
# We're generating a tool call, so reasoning must be ended.
return True
if input_ids[i] in (new_turn_token_id, tool_response_token_id):
# We found a new turn or tool response token so don't consider
# reasoning ended yet, since the model starts new reasoning
# after these tokens.
return False
if input_ids[i] == end_token_id:
return True
return False
# ------------------------------------------------------------------
# Non-streaming path
# ------------------------------------------------------------------
def extract_reasoning(
self,
model_output: str,
request: "ChatCompletionRequest | ResponsesRequest",
) -> tuple[str | None, str | None]:
"""Extract reasoning, stripping the ``thought\\n`` role label."""
if self.start_token not in model_output and self.end_token not in model_output:
# Default to content history if no tags are present
# (or if they were stripped)
return None, model_output
reasoning, content = super().extract_reasoning(model_output, request)
if reasoning is not None:
reasoning = _strip_thought_label(reasoning)
return reasoning, content
# ------------------------------------------------------------------
# Streaming path
# ------------------------------------------------------------------
def extract_reasoning_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
) -> DeltaMessage | None:
"""Extract streaming reasoning, stripping ``thought\\n`` from the
first reasoning delta(s).
The ``thought\\n`` prefix may arrive as a single delta or split
across multiple deltas (e.g. ``"thought"`` then ``"\\n"``). We
buffer early reasoning tokens until we can determine whether the
prefix is present, then emit the buffered content minus the
prefix.
Unlike the previous implementation which reconstructed accumulated
reasoning from ``current_text``, this uses instance state
(``_reasoning_text``) to track only the reasoning content returned
by the base parser. This is necessary because
``skip_special_tokens=True`` (the vLLM default) causes the
``<|channel>`` delimiter to be invisible in ``current_text``,
making it impossible to separate pre-reasoning content from
reasoning content via string matching.
"""
result = super().extract_reasoning_streaming(
previous_text,
current_text,
delta_text,
previous_token_ids,
current_token_ids,
delta_token_ids,
)
if result is None:
return None
if result.reasoning is None:
return result
# Accumulate ONLY the reasoning text from base parser results.
# This is immune to pre-reasoning content pollution.
self._reasoning_text += result.reasoning
# Once the prefix has been handled, all subsequent reasoning
# deltas pass through unchanged.
if self._prefix_stripped:
return result
# ---- Prefix stripping logic ----
# Case 1: We've accumulated enough to confirm the prefix is
# present. Strip it and pass through the remainder.
if self._reasoning_text.startswith(_THOUGHT_PREFIX):
prefix_len = len(_THOUGHT_PREFIX)
# How much reasoning was accumulated before this delta?
prev_reasoning_len = len(self._reasoning_text) - len(result.reasoning)
if prev_reasoning_len >= prefix_len:
# Prefix was already consumed by prior deltas; this
# delta is entirely real content — pass through.
self._prefix_stripped = True
return result
else:
# Part or all of the prefix is in this delta.
chars_of_prefix_in_delta = prefix_len - prev_reasoning_len
stripped = result.reasoning[chars_of_prefix_in_delta:]
if stripped:
self._prefix_stripped = True
result.reasoning = stripped
return result
else:
if len(self._reasoning_text) >= prefix_len:
self._prefix_stripped = True
result.reasoning = ""
return result
return None
# Case 2: Accumulated text is a strict prefix of
# _THOUGHT_PREFIX (e.g. we've only seen "thou" so far).
# Buffer by suppressing — we can't yet tell if this will
# become the full prefix or diverge.
if _THOUGHT_PREFIX.startswith(self._reasoning_text):
return None
# Case 3: Accumulated text doesn't match the thought prefix
# at all. This means prior deltas were buffered (suppressed
# by Case 2) but the text diverged. Re-emit the full
# accumulated text to avoid data loss.
self._prefix_stripped = True
result.reasoning = self._reasoning_text
return result
def _strip_thought_label(text: str) -> str:
"""Remove the ``thought\\n`` role label from the beginning of text.
Mirrors ``vllm.reasoning.gemma4_utils._strip_thought_label`` from the
offline parser.
"""
if text.startswith(_THOUGHT_PREFIX):
return text[len(_THOUGHT_PREFIX) :]
return text
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
"""Gemma4 thinking/reasoning output parsing utilities for offline inference.
Standalone functions that parse decoded model text to extract structured
thinking content from Gemma4 models. These are pure-Python utilities with
zero heavy dependencies — they work on raw decoded strings from any
inference backend (vLLM, HuggingFace, TGI, etc.).
For the OpenAI-compatible API reasoning parser (streaming +
non-streaming), see ``vllm.reasoning.gemma4_reasoning_parser``.
For tool call parsing, see ``vllm.tool_parsers.gemma4_utils``.
Usage with vLLM offline inference::
from vllm import LLM, SamplingParams
from vllm.reasoning.gemma4_utils import parse_thinking_output
llm = LLM(model="google/gemma-4-it")
outputs = llm.generate(prompt, SamplingParams(...))
text = tokenizer.decode(outputs[0].outputs[0].token_ids, skip_special_tokens=False)
# Extract thinking / answer (works with or without enable_thinking)
result = parse_thinking_output(text)
print(result["thinking"]) # chain-of-thought or None
print(result["answer"]) # final answer
Ported from ``transformers.models.gemma4.utils_gemma4`` so that vLLM users
do not need a transformers dependency for output parsing.
"""
# ---- Thinking Mode Utility ----
# Thinking delimiter tokens as they appear in decoded text.
# Gemma4 uses <|channel> (start) and <channel|> (end) as thinking delimiters.
_THINKING_START_TAG = "<|channel>"
_THINKING_END_TAG = "<channel|>"
# Sentinel tokens that may appear in decoded output.
_TURN_END_TAG = "<turn|>"
def parse_thinking_output(text: str) -> dict[str, str | None]:
"""Parse decoded Gemma4 model output.
Use this on **all** Gemma4 output regardless of whether thinking mode
was enabled. It handles three cases:
1. **Thinking enabled, tags present** — splits on ``<|channel>``/
``<channel|>`` to separate chain-of-thought from the answer and
strips the ``thought\\n`` role label.
2. **Thinking disabled, spurious label** — strips the bare
``thought\\n`` prefix that some Gemma4 models emit even
without thinking mode.
3. **Clean output** — returns the text unchanged.
The answer text is always cleaned of trailing sentinel tokens
(``<turn|>``, ``<eos>``, etc.).
Args:
text: Decoded model output text (from ``tokenizer.decode(...)``).
Returns:
A dict with keys:
- ``"thinking"``: The chain-of-thought text, or ``None`` if no
thinking delimiters were found.
- ``"answer"``: The final answer text.
Example::
>>> from vllm.reasoning.gemma4_utils import parse_thinking_output
>>> output_text = tokenizer.decode(outputs[0], skip_special_tokens=False)
>>> result = parse_thinking_output(output_text)
>>> print(result["thinking"]) # chain-of-thought reasoning or None
>>> print(result["answer"]) # final answer
"""
if _THINKING_END_TAG in text:
parts = text.split(_THINKING_END_TAG, 1)
thinking_block = parts[0]
answer = _clean_answer(parts[1])
# Extract thinking content: strip the start tag if present
if _THINKING_START_TAG in thinking_block:
thinking = thinking_block.split(_THINKING_START_TAG, 1)[1]
else:
thinking = thinking_block
# Strip the "thought\n" channel role label the model emits inside
# <|channel>thought\n...<channel|> (analogous to "user\n" in
# <|turn>user\n...<turn|>).
thinking = _strip_thought_label(thinking.strip())
thinking = thinking.strip()
return {"thinking": thinking, "answer": answer}
# No thinking delimiters found.
# Strip spurious "thought\n" role label that some Gemma4 models sometimes
# emit even without thinking mode enabled, then clean trailing tokens.
answer = _strip_thought_label(text)
answer = _clean_answer(answer)
return {"thinking": None, "answer": answer}
def _strip_thought_label(text: str) -> str:
"""Strip the spurious ``thought\\n`` label from the start of text.
Only strips when ``thought`` appears as the very first word followed by
a newline — preserving the word ``thought`` in any other context.
"""
if text.startswith("thought\n"):
return text[len("thought\n") :]
return text
def _clean_answer(text: str) -> str:
"""Clean trailing sentinel tokens from the answer text.
Strips ``<turn|>``, ``<eos>``, and surrounding whitespace that the
model appends at the end of its response.
"""
text = text.strip()
# Strip trailing <turn|> (Gemma4 turn-end marker)
if text.endswith(_TURN_END_TAG):
text = text[: -len(_TURN_END_TAG)].rstrip()
# Strip trailing <eos> if present
if text.endswith("<eos>"):
text = text[:-5].rstrip()
return text
......@@ -154,6 +154,10 @@ _TOOL_PARSERS_TO_REGISTER = {
"functiongemma_tool_parser",
"FunctionGemmaToolParser",
),
"gemma4": (
"gemma4_tool_parser",
"Gemma4ToolParser",
),
}
......
This diff is collapsed.
......@@ -300,6 +300,28 @@ class ModelArchConfigConvertorBase:
return model_arch_config
class CohereAsrModelArchConfigConvertor(ModelArchConfigConvertorBase):
def get_total_num_attention_heads(self) -> int:
return self.hf_text_config.transf_decoder["config_dict"]["num_attention_heads"]
def get_head_size(self) -> int:
hidden_size = self.hf_text_config.transf_decoder["config_dict"]["hidden_size"]
num_attention_heads = self.hf_text_config.transf_decoder["config_dict"][
"num_attention_heads"
]
return hidden_size // num_attention_heads
def get_total_num_kv_heads(self) -> int:
enc_num_kv_heads = self.hf_text_config.encoder["n_heads"]
dec_num_kv_heads = self.hf_text_config.transf_decoder["config_dict"][
"num_attention_heads"
]
assert enc_num_kv_heads == dec_num_kv_heads, (
"Encoder and decoder must have the same number of kv heads"
)
return enc_num_kv_heads
class MambaModelArchConfigConvertor(ModelArchConfigConvertorBase):
def get_head_size(self) -> int:
return 0
......@@ -423,6 +445,16 @@ class LongCatFlashMTPModelArchConfigConvertor(ModelArchConfigConvertorBase):
return getattr(self.hf_text_config, "num_nextn_predict_layers", 1)
class Gemma4ModelArchConfigConvertor(ModelArchConfigConvertorBase):
def get_head_size(self) -> int:
# Gemma4 uses dual head dimensions: head_dim (sliding attention)
# and global_head_dim (full attention). Return the largest so
# that attention backends allocate buffers large enough for both.
head_dim = getattr(self.hf_text_config, "head_dim", 0)
global_head_dim = getattr(self.hf_text_config, "global_head_dim", 0)
return max(head_dim, global_head_dim) or super().get_head_size()
# hf_config.model_type -> convertor class
MODEL_ARCH_CONFIG_CONVERTORS = {
"mamba": MambaModelArchConfigConvertor,
......@@ -433,6 +465,8 @@ MODEL_ARCH_CONFIG_CONVERTORS = {
"mpt": MPTModelArchConfigConvertor,
"dbrx": DbrxModelArchConfigConvertor,
"falcon": FalconModelArchConfigConvertor,
"gemma4": Gemma4ModelArchConfigConvertor,
"gemma4_text": Gemma4ModelArchConfigConvertor,
"RefinedWeb": FalconModelArchConfigConvertor,
"RefinedWebModel": FalconModelArchConfigConvertor,
"nemotron-nas": NemotronNasModelArchConfigConvertor,
......
......@@ -1040,6 +1040,8 @@ def unified_attention(
num_seqs=num_seqs,
BLOCK_M=BLOCK_M,
USE_FP8=output_scale is not None,
num_stages =1
)
else:
kernel_unified_attention_3d[
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment