Unverified commit 8a34c508, authored by Andrew Barnes and committed by GitHub
Browse files

[ROCm] Remove unnecessary fp8 roundtrip in gather cache NHD dequant (#39122)


Signed-off-by: Bortlesboat <bortstheboat@gmail.com>
parent ed2f282b
@@ -112,10 +112,12 @@ if current_platform.is_rocm():
if DEQUANT:
k_scale = tl.load(k_scale_ptr)
v_scale = tl.load(v_scale_ptr)
-k_dtype = k_reg.dtype
-v_dtype = v_reg.dtype
-k_reg = (k_reg.to(tl.float32) * k_scale).to(k_dtype)
-v_reg = (v_reg.to(tl.float32) * v_scale).to(v_dtype)
+k_reg = (k_reg.to(tl.float32) * k_scale).to(
+key_ptr_offset.dtype.element_ty
+)
+v_reg = (v_reg.to(tl.float32) * v_scale).to(
+value_ptr_offset.dtype.element_ty
+)
tl.store(key_ptr_offset + col_offsets, k_reg)
tl.store(value_ptr_offset + col_offsets, v_reg)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment