Commit 221bf725 (unverified), authored by Jianyu Huang, committed by GitHub
Browse files

output type conversion fix (#27159)

parent b3aba04e
@@ -134,10 +134,7 @@ def matmul_kernel_persistent(
 bias_ptrs = bias_ptr + offs_cn
 bias = tl.load(bias_ptrs, mask=offs_cn < N, other=0.0).to(tl.float32)
 accumulator += bias
-if c_ptr.dtype.element_ty == tl.float8e4nv:
-    c = accumulator.to(tl.float8e4nv)
-else:
-    c = accumulator.to(tl.float16)
+c = accumulator.to(c_ptr.dtype.element_ty)
 tl.store(c_ptrs, c, mask=c_mask)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment