Commit c334b741 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.15.1-dev_2.28' into 'v0.15.1-dev'

修复qwen3.5的使用dtype为fp16的picecwise的推理模式

See merge request dcutoolkit/deeplearing/vllm!448
parents 4262c4d9 b5e8d01e
...@@ -458,6 +458,10 @@ class RMSNormGated(CustomOp): ...@@ -458,6 +458,10 @@ class RMSNormGated(CustomOp):
- norm_before_gate=True: out = norm(x) * silu(z) - norm_before_gate=True: out = norm(x) * silu(z)
- norm_before_gate=False: out = norm(x * silu(z)) - norm_before_gate=False: out = norm(x * silu(z))
""" """
orig_dtype = x.dtype
x = x.float()
weight = self.weight.float()
z = z.float() if z is not None else None
# Apply gating before normalization if needed # Apply gating before normalization if needed
if z is not None and not self.norm_before_gate: if z is not None and not self.norm_before_gate:
x = x * F.silu(z) x = x * F.silu(z)
...@@ -467,7 +471,7 @@ class RMSNormGated(CustomOp): ...@@ -467,7 +471,7 @@ class RMSNormGated(CustomOp):
# Standard RMS norm across the last dimension # Standard RMS norm across the last dimension
variance = x.pow(2).mean(dim=-1, keepdim=True) variance = x.pow(2).mean(dim=-1, keepdim=True)
x_normed = x * torch.rsqrt(variance + self.eps) x_normed = x * torch.rsqrt(variance + self.eps)
out = x_normed * self.weight out = x_normed * weight
else: else:
# Group RMS norm # Group RMS norm
from einops import rearrange from einops import rearrange
...@@ -475,13 +479,13 @@ class RMSNormGated(CustomOp): ...@@ -475,13 +479,13 @@ class RMSNormGated(CustomOp):
x_group = rearrange(x, "... (g d) -> ... g d", d=self.group_size) x_group = rearrange(x, "... (g d) -> ... g d", d=self.group_size)
variance = x_group.pow(2).mean(dim=-1, keepdim=True) variance = x_group.pow(2).mean(dim=-1, keepdim=True)
x_normed = x_group * torch.rsqrt(variance + self.eps) x_normed = x_group * torch.rsqrt(variance + self.eps)
out = rearrange(x_normed, "... g d -> ... (g d)") * self.weight out = rearrange(x_normed, "... g d -> ... (g d)") * weight
# Apply gating after normalization if needed # Apply gating after normalization if needed
if z is not None and self.norm_before_gate: if z is not None and self.norm_before_gate:
out = out * F.silu(z) out = out * F.silu(z)
return out return out.to(orig_dtype)
def forward_cuda( def forward_cuda(
self, x: torch.Tensor, z: torch.Tensor | None = None self, x: torch.Tensor, z: torch.Tensor | None = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment