Unverified Commit 1afb6256 authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

Fix incorrect dtype in LayerNormLinear (#483)


Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d097883e
@@ -112,7 +112,7 @@ class _LayerNormLinear(torch.autograd.Function):
             ub_obj_lnout = get_ub("qkv_fprop")
             ln_out = ub_obj_lnout.get_ubuf_output(0)
         else:
-            ln_out_dtype = torch.uint8 if fp8 else inputmat.dtype
+            ln_out_dtype = torch.uint8 if (fp8 and not return_layernorm_output) else inputmat.dtype
             ln_out = torch.empty_like(inputmat, dtype=ln_out_dtype)
         if ub_atomic_gemm_ag:
             assert fp8, "AtomicGemm overlap supported only for FP8 GEMM."
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment