Unverified Commit a0e1cf99 authored by Przemyslaw Tredak's avatar Przemyslaw Tredak Committed by GitHub
Browse files

Fix for the RMSNorm tests/doc/ONNX export to match the actual implementation (#364)



Fix for the RMSNorm tests/doc/ONNX export to match the actual
implementation
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>
parent 66ff2e36
...@@ -333,13 +333,14 @@ class TorchRMSNorm(nn.Module): ...@@ -333,13 +333,14 @@ class TorchRMSNorm(nn.Module):
self.register_parameter("weight", self.weight) self.register_parameter("weight", self.weight)
def forward(self, x):
    """Apply RMS normalization to the last dimension of *x*.

    Matches the fixed reference semantics from this commit: the mean of
    squares is computed in fp32, ``eps`` is added to the mean-square
    *before* the square root (RMS_eps(x) = sqrt(mean(x^2) + eps)), and
    the result is cast back to the input dtype.

    Args:
        x: input tensor; normalization runs over its last dimension,
           which is assumed to have size ``self.in_features``.

    Returns:
        Tensor of the same shape and dtype as ``x``:
        ``weight * x / sqrt(mean(x**2, dim=-1) + eps)``.
    """
    # Accumulate the sum of squares in fp32 for numerical stability,
    # regardless of the input dtype (e.g. fp16/bf16).
    norm_x2 = torch.sum(x.float() ** 2, dim=-1, keepdim=True)
    d_x = self.in_features
    # eps goes inside the square root (added to the mean-square), not
    # onto the RMS itself — this was the bug the commit fixes.
    rms_x2 = norm_x2 / d_x + self.eps
    r_rms_x = rms_x2 ** (-1.0 / 2)
    x_normed = x * r_rms_x
    # Scale in fp32, then cast back to the caller's dtype.
    return (self.weight.float() * x_normed).to(x.dtype)
class TorchLayerNormLinear(nn.Module): class TorchLayerNormLinear(nn.Module):
def __init__(self, in_features: int, out_features: int, def __init__(self, in_features: int, out_features: int,
...@@ -877,12 +878,14 @@ def test_linear_accuracy(dtype, bs, model): ...@@ -877,12 +878,14 @@ def test_linear_accuracy(dtype, bs, model):
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("model", model_configs.keys())
def test_rmsnorm_accuracy(dtype, bs, model): @pytest.mark.parametrize("eps", [1e-1, 1e-3, 1e-5, 1e-7])
def test_rmsnorm_accuracy(dtype, bs, model, eps):
config = model_configs[model] config = model_configs[model]
te_rmsnorm = ( te_rmsnorm = (
RMSNorm( RMSNorm(
config.hidden_size, config.hidden_size,
eps=eps,
) )
.to(dtype=dtype) .to(dtype=dtype)
.cuda() .cuda()
...@@ -892,6 +895,7 @@ def test_rmsnorm_accuracy(dtype, bs, model): ...@@ -892,6 +895,7 @@ def test_rmsnorm_accuracy(dtype, bs, model):
torch_rmsnorm = ( torch_rmsnorm = (
TorchRMSNorm( TorchRMSNorm(
config.hidden_size, config.hidden_size,
eps=eps,
) )
.to(dtype=dtype) .to(dtype=dtype)
.cuda() .cuda()
......
...@@ -79,12 +79,12 @@ class RMSNorm(torch.nn.Module): ...@@ -79,12 +79,12 @@ class RMSNorm(torch.nn.Module):
the paper `Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__ the paper `Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__
.. math:: .. math::
        y = \frac{x}{RMS_\varepsilon(x)} * \gamma
where where
.. math:: .. math::
        RMS_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=0}^{n} x_i^2 + \varepsilon}
:math:`\gamma` is a learnable affine transform parameter of size :attr:`hidden_size` :math:`\gamma` is a learnable affine transform parameter of size :attr:`hidden_size`
......
...@@ -382,14 +382,13 @@ def onnx_rmsnorm_fwd(g, inputs, weight, eps, zero_centered_gamma): ...@@ -382,14 +382,13 @@ def onnx_rmsnorm_fwd(g, inputs, weight, eps, zero_centered_gamma):
inputs_float = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT) inputs_float = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT)
norm = g.op("ReduceL2", inputs_float, axes_i=[axis]) sum_square = g.op("ReduceSumSquare", inputs_float, axes_i=[axis])
shape = g.op("Shape", inputs_float, start_i=-1) shape = g.op("Shape", inputs_float, start_i=-1)
shape_f = g.op("Cast", shape, to_i=_C_onnx.TensorProtoDataType.FLOAT) shape_f = g.op("Cast", shape, to_i=_C_onnx.TensorProtoDataType.FLOAT)
n_reciprocal = g.op("Reciprocal", shape_f) mean_squared = g.op("Div", sum_square, shape_f)
sqrt_n_reciprocal = g.op("Sqrt", n_reciprocal)
rms = g.op("Mul", norm, sqrt_n_reciprocal)
eps_tensor = g.op("ConstantOfShape", shape, value_t=torch.tensor([eps], dtype=torch.float32)) eps_tensor = g.op("ConstantOfShape", shape, value_t=torch.tensor([eps], dtype=torch.float32))
rms_eps = g.op("Add", rms, eps_tensor) rms_squared = g.op("Add", mean_squared, eps_tensor)
rms_eps = g.op("Sqrt", rms_squared)
normalized_input = g.op("Div", inputs_float, rms_eps) normalized_input = g.op("Div", inputs_float, rms_eps)
result = g.op("Mul", weight, normalized_input) result = g.op("Mul", weight, normalized_input)
result = g.op("Cast", result, to_i=get_TensorProtoDataType(inputs)) result = g.op("Cast", result, to_i=get_TensorProtoDataType(inputs))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment