"tests/entrypoints/pooling/openai/test_classification.py" did not exist on "370661856bcfc4cdc9a88580cb70d66b7ac9fc7c"
Unverified commit 20bd6f4d, authored by Dhia Eddine Rhaiem, committed by GitHub
Browse files

[FalconH1] Fix output dtype in RMSNorm fallback path for Falcon-H1 (e.g. 0.5B) (#18500)


Signed-off-by: dhia.rhaiem <dhia.rhaiem@tii.ae>
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Co-authored-by: Ilyas Chahed <ilyas.chahed@tii.ae>
Co-authored-by: Jingwei Zuo <jingwei.zuo@tii.ae>
parent 1f079540
......@@ -77,7 +77,7 @@ class Mixer2RMSNormGated(CustomOp):
input_dtype = x.dtype
x = x * nn.functional.silu(gate.to(torch.float32))
if not self.use_rms_norm:
return x
return x.to(input_dtype)
if self.n_groups == 1:
if self.tp_size > 1:
......@@ -117,9 +117,11 @@ class Mixer2RMSNormGated(CustomOp):
x: torch.Tensor,
gate: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
input_dtype = x.dtype
if not self.use_rms_norm:
return x * nn.functional.silu(gate.to(torch.float32))
# Keep gate in float32 for numerical stability during silu
return x * nn.functional.silu(gate.to(
torch.float32)).to(input_dtype)
if self.tp_size > 1 or self.n_groups != 1:
return self.forward_native(x, gate)
......
......@@ -453,7 +453,6 @@ class FalconH1Model(nn.Module):
attn_metadata = get_forward_context().attn_metadata
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.mamba_chunk_size,
input_ids=input_ids,
attn_metadata=attn_metadata,
)
if get_pp_group().is_first_rank:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment