Commit 5dc2dc0d authored by zhuwenwen's avatar zhuwenwen
Browse files

Update layernorm.py

parent 50281e36
......@@ -253,7 +253,8 @@ class RMSNorm(CustomOp):
return self.rocm_norm_func_with_add(x, residual, self.weight.data,
self.variance_epsilon)
else:
return norm_func(x, self.weight.data, self.variance_epsilon)
return self.rocm_norm_func(x, self.weight.data,
self.variance_epsilon)
def forward_apex(
self,
......@@ -265,7 +266,7 @@ class RMSNorm(CustomOp):
norm_func = dispatch_cuda_rmsnorm_func(add_residual)
if add_residual:
return norm_func(x, residual, self.weight.data,
return self.rocm_norm_func_with_add(x, residual, self.weight.data,
self.variance_epsilon)
else:
return fused_rms_norm_affine(x, self.weight.data, torch.Size((x.shape[-1],)), self.variance_epsilon)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment