Unverified Commit a810b5b0 authored by TJian's avatar TJian Committed by GitHub
Browse files

[BugFix] [ROCm]: Bugfix and handle addition case of input for `rocm_aiter_rms_norm` (#17857)


Signed-off-by: default avatartjtanaa <tunjian.tan@embeddedllm.com>
parent 009b3d53
......@@ -28,6 +28,7 @@ AITER_MODEL_LIST = [
"Qwen/Qwen-7B-Chat",
"Qwen/Qwen2.5-0.5B-Instruct",
"TitanML/tiny-mixtral",
"Qwen/Qwen3-8B",
]
......@@ -78,6 +79,9 @@ AITER_MODEL_LIST = [
"Qwen/Qwen2.5-0.5B-Instruct", # qwen2
marks=[pytest.mark.core_model],
),
pytest.param(
"Qwen/Qwen3-8B", # qwen (text-only)
),
pytest.param("stabilityai/stablelm-3b-4e1t"), # stablelm
pytest.param("bigcode/starcoder2-3b"), # starcoder2
pytest.param(
......
......@@ -46,6 +46,12 @@ def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> torch.Tensor:
import aiter as rocm_aiter
if x.dim() > 2:
x_original_shape = x.shape
x = x.reshape(-1, x_original_shape[-1])
x = rocm_aiter.rms_norm(x, weight, variance_epsilon)
return x.reshape(x_original_shape)
return rocm_aiter.rms_norm(x, weight, variance_epsilon)
......@@ -55,16 +61,17 @@ def rocm_aiter_fused_add_rms_norm(
import aiter as rocm_aiter
# Assuming the correct signature for rmsnorm2d_fwd_with_add
residual_out = torch.empty_like(residual)
output = torch.empty_like(x)
rocm_aiter.rmsnorm2d_fwd_with_add(
x, # output
output, # output
x, # input
residual, # residual input
residual, # residual output
residual_out, # residual output
weight,
variance_epsilon,
)
return x, residual
return output, residual_out
def dispatch_cuda_rmsnorm_func(add_residual: bool):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment