out_ref=rearrange(F.layer_norm(rearrange(x_ref,"... (g d) -> ... g d",d=group_size).float(),(group_size,),eps=1e-5),"... g d -> ... (g d)")*weight_ref.float()
ifhas_bias:
out_ref=out_ref+bias_ref.float()
out_pt=rearrange(F.layer_norm(rearrange(x_pt,"... (g d) -> ... g d",d=group_size),(group_size,),eps=1e-5),"... g d -> ... (g d)")*weight_pt