Unverified Commit 7c3f88b2 authored by Thomas Parnell's avatar Thomas Parnell Committed by GitHub
Browse files

[Bugfix] Remove false-positive format mismatch warnings in FLA ops (#38255)


Signed-off-by: default avatarThomas Parnell <tpa@zurich.ibm.com>
parent 6557f493
......@@ -7,7 +7,6 @@
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import warnings
import torch
......@@ -184,13 +183,6 @@ def chunk_gated_delta_rule(
"ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
)
assert len(beta.shape) == 3, "beta must be of shape [B, T, H]."
if q.shape[1] < q.shape[2]:
warnings.warn(
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
"Please verify your input tensor format matches the expected shape [B, T, H, ...].",
stacklevel=2,
)
if cu_seqlens is not None:
if q.shape[0] != 1:
raise ValueError(
......
......@@ -7,7 +7,6 @@
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import warnings
import torch
......@@ -252,14 +251,6 @@ def chunk_local_cumsum(
output_dtype: torch.dtype | None = torch.float,
**kwargs,
) -> torch.Tensor:
if not head_first and g.shape[1] < g.shape[2]:
warnings.warn(
f"Input tensor shape suggests potential format mismatch: seq_len ({g.shape[1]}) < num_heads ({g.shape[2]}). "
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
"when head_first=False was specified. "
"Please verify your input tensor format matches the expected shape [B, T, H, ...].",
stacklevel=2,
)
if cu_seqlens is not None:
assert g.shape[0] == 1, (
"Only batch size 1 is supported when cu_seqlens are provided"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment