Unverified Commit 7c3f88b2 authored by Thomas Parnell's avatar Thomas Parnell Committed by GitHub
Browse files

[Bugfix] Remove false-positive format mismatch warnings in FLA ops (#38255)


Signed-off-by: default avatarThomas Parnell <tpa@zurich.ibm.com>
parent 6557f493
...@@ -7,7 +7,6 @@ ...@@ -7,7 +7,6 @@
# the following copyright notice: # the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501 # ruff: noqa: E501
import warnings
import torch import torch
...@@ -184,13 +183,6 @@ def chunk_gated_delta_rule( ...@@ -184,13 +183,6 @@ def chunk_gated_delta_rule(
"ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16." "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
) )
assert len(beta.shape) == 3, "beta must be of shape [B, T, H]." assert len(beta.shape) == 3, "beta must be of shape [B, T, H]."
if q.shape[1] < q.shape[2]:
warnings.warn(
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
"Please verify your input tensor format matches the expected shape [B, T, H, ...].",
stacklevel=2,
)
if cu_seqlens is not None: if cu_seqlens is not None:
if q.shape[0] != 1: if q.shape[0] != 1:
raise ValueError( raise ValueError(
......
...@@ -7,7 +7,6 @@ ...@@ -7,7 +7,6 @@
# the following copyright notice: # the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501 # ruff: noqa: E501
import warnings
import torch import torch
...@@ -252,14 +251,6 @@ def chunk_local_cumsum( ...@@ -252,14 +251,6 @@ def chunk_local_cumsum(
output_dtype: torch.dtype | None = torch.float, output_dtype: torch.dtype | None = torch.float,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
if not head_first and g.shape[1] < g.shape[2]:
warnings.warn(
f"Input tensor shape suggests potential format mismatch: seq_len ({g.shape[1]}) < num_heads ({g.shape[2]}). "
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
"when head_first=False was specified. "
"Please verify your input tensor format matches the expected shape [B, T, H, ...].",
stacklevel=2,
)
if cu_seqlens is not None: if cu_seqlens is not None:
assert g.shape[0] == 1, ( assert g.shape[0] == 1, (
"Only batch size 1 is supported when cu_seqlens are provided" "Only batch size 1 is supported when cu_seqlens are provided"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment