Unverified Commit eb1051fb, authored by Ye (Charlotte) Qi, committed by GitHub

[ROCm] Guard group quant RMS norm fusion patterns (#30239)

parent 80433e22
......@@ -490,6 +490,8 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
        # as the latter is a subset of the former in torch ops
        for epsilon in [1e-5, 1e-6]:
            # Fuse fused_add_rms_norm + fp8 group quant
            # Only register group quant patterns on CUDA where the C++ op exists
            if current_platform.is_cuda():
                FusedAddRMSNormGroupQuantPattern(
                    epsilon, FP8_DTYPE, group_shape=GroupShape(1, 128)
                ).register(self.patterns)
......
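The guard in the hunk above can be illustrated with a small, self-contained sketch. It is not vLLM code: `FakePlatform`, `FakePattern`, `build_patterns`, and the pattern labels are hypothetical stand-ins for `current_platform`, the pattern classes, and the pass constructor. The sketch assumes the group quant pattern depends on an op that exists only on CUDA, while some other pattern has no such dependency.

```python
from dataclasses import dataclass


@dataclass
class FakePlatform:
    """Hypothetical stand-in for vLLM's current_platform."""
    name: str

    def is_cuda(self) -> bool:
        return self.name == "cuda"


@dataclass
class FakePattern:
    """Hypothetical stand-in for a fusion pattern object."""
    label: str
    epsilon: float

    def register(self, registry: list) -> None:
        # Record the pattern; a real pass would register a graph rewrite.
        registry.append((self.label, self.epsilon))


def build_patterns(platform: FakePlatform) -> list:
    """Register patterns per epsilon, skipping CUDA-only ones elsewhere."""
    registry: list = []
    for epsilon in [1e-5, 1e-6]:
        # Assumption: the group quant pattern needs a CUDA-only op,
        # so it is registered only when the platform reports CUDA.
        if platform.is_cuda():
            FakePattern("group_quant", epsilon).register(registry)
        # Hypothetical pattern with no platform-specific dependency.
        FakePattern("ungated", epsilon).register(registry)
    return registry


cuda_patterns = build_patterns(FakePlatform("cuda"))
rocm_patterns = build_patterns(FakePlatform("rocm"))
```

With this sketch, the ROCm registry contains no `group_quant` entries, so the matcher never tries to rewrite a graph into an op that the platform does not provide. Guarding at registration time keeps the check in one place rather than inside each pattern.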