Unverified Commit b8665383 authored by Gregory Shtrasberg's avatar Gregory Shtrasberg Committed by GitHub
Browse files

[ROCm] Fix GPT-OSS import for triton 3.6 (#37453)


Signed-off-by: default avatarGregory Shtrasberg <Gregory.Shtrasberg@amd.com>
parent 0e9358c1
...@@ -48,9 +48,16 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps=8): ...@@ -48,9 +48,16 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps=8):
value_layout = StridedLayout value_layout = StridedLayout
if on_gfx950(): if on_gfx950():
from triton_kernels.tensor_details.layout import GFX950MXScaleLayout try:
# triton < 3.6
from triton_kernels.tensor_details.layout import GFX950MXScaleLayout
scale_layout = GFX950MXScaleLayout scale_layout = GFX950MXScaleLayout
except ImportError:
# triton >= 3.6
from triton_kernels.tensor_details.layout import CDNA4MXScaleLayout
scale_layout = CDNA4MXScaleLayout
else: else:
scale_layout = StridedLayout scale_layout = StridedLayout
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment