Unverified Commit 5898702e authored by Tim Moon, committed by GitHub

Support arbitrary output dtypes in PyT GEMM functions (#75)



* Deprecate fp32_output option for PyT linear layers

Automatically detect dtype for user-provided output tensors.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove deprecated options
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 5c7c6016
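
In practice the change is caller-facing: instead of requesting FP32 results through a flag, the caller provides (or receives) an output tensor and the GEMM helpers read the dtype off that tensor. A rough before/after sketch; the argument lists below are abbreviated, not the exact `gemm()` signature:

```python
import torch

# Hypothetical before/after, with most gemm() arguments elided:
#   before: gemm(A, B, dtype, workspace, fp32_output=True)
#   after:  gemm(A, B, dtype, workspace, out=out)  # output dtype taken from out.dtype
out = torch.empty(128, 256, dtype=torch.float32)  # caller-chosen output precision
print(out.dtype)  # torch.float32 -- the dtype the GEMM will now write/accumulate in
```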
@@ -371,7 +371,6 @@ def test_export_gemm(
             get_workspace(),
             bias=self.bias,
             use_bias=self.use_bias,
-            fp32_output=(self.precision==torch.float32),
             use_split_accumulator=False)
         return ret
...
@@ -26,9 +26,8 @@ def fp8_gemm(
     fp8_meta_tensor: tex.FP8TensorMeta = None,
     bias: Optional[torch.Tensor] = None,
     use_bias: bool = False,
-    fp32_output: bool = False,
     use_split_accumulator: bool = False,
-    D_dtype: tex.DType = None,
+    D_dtype: Optional[tex.DType] = None,
 ) -> torch.Tensor:
     """TN layout GEMM with fp8 inputs."""
@@ -41,15 +40,14 @@ def fp8_gemm(
         out = torch.empty(
             B.shape[0],
             A.shape[0],
-            dtype=torch.float32 if fp32_output else out_dtype,
+            dtype=out_dtype,
             device="cuda",
         )
         return_output = True
-    out_dtype = tex.DType.kFloat32 if fp32_output else TE_DType[out_dtype]
+    out_dtype = TE_DType[out.dtype] if D_dtype is None else D_dtype
     # Use bfloat16 as default bias_dtype
     bias_dtype = tex.DType.kBFloat16 if bias is None else TE_DType[bias.dtype]
-    out_dtype = D_dtype if D_dtype is not None else out_dtype
 
     _ = torch.ops.tex_ts.te_gemm_ts(
         A,
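
The core of the `fp8_gemm` change is the single precedence rule `out_dtype = TE_DType[out.dtype] if D_dtype is None else D_dtype`. A minimal standalone sketch of that rule, with a dict standing in for Transformer Engine's `TE_DType` mapping:

```python
import torch

# Stand-in for the torch-dtype -> tex.DType mapping used by fp8_gemm.
TE_DType = {
    torch.float32: "kFloat32",
    torch.float16: "kFloat16",
    torch.bfloat16: "kBFloat16",
}

def resolve_out_dtype(out, D_dtype=None):
    # An explicit D_dtype override wins; otherwise use the output tensor's dtype.
    return TE_DType[out.dtype] if D_dtype is None else D_dtype

out = torch.empty(4, 4, dtype=torch.bfloat16)
print(resolve_out_dtype(out))                      # kBFloat16 (from out.dtype)
print(resolve_out_dtype(out, D_dtype="kFloat32"))  # kFloat32 (explicit override)
```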
@@ -94,7 +92,6 @@ def gemm(
     out: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
     use_bias: bool = False,
-    fp32_output: bool = False,
 ) -> Tuple[Union[torch.Tensor, None], ...]:
     """Non FP8 GEMM."""
@@ -104,16 +101,12 @@ def gemm(
     empty_tensor = torch.Tensor()
     fp8_index = -1  # dummy index
-    input_dtype = TE_DType[dtype]
-    output_dtype = tex.DType.kFloat32 if fp32_output else input_dtype
-    bias_dtype = output_dtype if bias is None else TE_DType[bias.dtype]
     return_output = False
     if out is None:
         out = torch.empty(
             B.shape[1] if transb else B.shape[0],
             A.shape[0] if transa else A.shape[1],
-            dtype=torch.float32 if fp32_output else dtype,
+            dtype=dtype,
             device="cuda",
         )
         return_output = True
@@ -124,14 +117,21 @@ def gemm(
         gelu_input = empty_tensor
     if grad and use_bias:
-        grad_bias = torch.empty(
-            B.shape[1], dtype=torch.float32 if fp32_output else dtype, device="cuda"
-        )
+        grad_bias = torch.empty(B.shape[1], dtype=out.dtype, device="cuda")
     else:
         grad_bias = empty_tensor
     bias = bias if use_bias else empty_tensor
 
+    assert A.dtype == dtype and B.dtype == dtype, \
+        f'Expected dtype={dtype}, but found A.dtype={A.dtype} and B.dtype={B.dtype}'
+    input_dtype = TE_DType[dtype]
+    output_dtype = TE_DType[out.dtype]
+    if use_bias:
+        bias_dtype = TE_DType[grad_bias.dtype] if grad else TE_DType[bias.dtype]
+    else:
+        bias_dtype = output_dtype
+
     _ = torch.ops.tex_ts.te_gemm_ts(
         A,
         empty_tensor,
...
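
The non-FP8 `gemm()` now derives input, output, and bias dtypes from the tensors themselves rather than from `fp32_output`. A small self-contained model of the added bookkeeping, with `TE_DType` again mocked as a dict and toy CPU tensors:

```python
import torch

TE_DType = {torch.float32: "kFloat32", torch.float16: "kFloat16", torch.bfloat16: "kBFloat16"}

def resolve_gemm_dtypes(A, B, dtype, out, bias=None, grad_bias=None, use_bias=False, grad=False):
    # Inputs must match the requested compute dtype; the output dtype follows
    # whatever tensor the caller passed (or that was freshly allocated) as out.
    assert A.dtype == dtype and B.dtype == dtype, \
        f"Expected dtype={dtype}, but found A.dtype={A.dtype} and B.dtype={B.dtype}"
    input_dtype = TE_DType[dtype]
    output_dtype = TE_DType[out.dtype]
    if use_bias:
        bias_dtype = TE_DType[grad_bias.dtype] if grad else TE_DType[bias.dtype]
    else:
        bias_dtype = output_dtype
    return input_dtype, output_dtype, bias_dtype

A = torch.randn(8, 16, dtype=torch.bfloat16)
B = torch.randn(8, 16, dtype=torch.bfloat16)
out = torch.empty(16, 16, dtype=torch.float32)  # e.g. an FP32 wgrad accumulation buffer
print(resolve_gemm_dtypes(A, B, torch.bfloat16, out))  # ('kBFloat16', 'kFloat32', 'kFloat32')
```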
@@ -960,7 +960,6 @@ class _LayerNormLinear(torch.autograd.Function):
                 ctx.activation_dtype,
                 get_workspace(),
                 accumulate=accumulate_wgrad_into_param_main_grad,
-                fp32_output=ctx.fuse_wgrad_accumulation,
                 out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                 use_split_accumulator=_2X_ACC_WGRAD,
             )
@@ -980,7 +979,6 @@ class _LayerNormLinear(torch.autograd.Function):
                 layout="NT",
                 grad=True,
                 accumulate=accumulate_wgrad_into_param_main_grad,
-                fp32_output=ctx.fuse_wgrad_accumulation,
                 out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
             )
         else:
@@ -994,7 +992,6 @@ class _LayerNormLinear(torch.autograd.Function):
                 grad=True,
                 use_bias=ctx.use_bias,
                 accumulate=accumulate_wgrad_into_param_main_grad,
-                fp32_output=ctx.fuse_wgrad_accumulation,
                 out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
             )
@@ -1652,7 +1649,6 @@ class _Linear(torch.autograd.Function):
                 ctx.activation_dtype,
                 get_workspace(),
                 accumulate=accumulate_wgrad_into_param_main_grad,
-                fp32_output=ctx.fuse_wgrad_accumulation,
                 out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                 use_split_accumulator=_2X_ACC_WGRAD,
             )
@@ -1665,7 +1661,6 @@ class _Linear(torch.autograd.Function):
                 layout="NT",
                 grad=True,
                 accumulate=accumulate_wgrad_into_param_main_grad,
-                fp32_output=ctx.fuse_wgrad_accumulation,
                 out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
             )
         else:
@@ -1679,7 +1674,6 @@ class _Linear(torch.autograd.Function):
                 grad=True,
                 use_bias=ctx.use_bias,
                 accumulate=accumulate_wgrad_into_param_main_grad,
-                fp32_output=ctx.fuse_wgrad_accumulation,
                 out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
             )
@@ -2358,7 +2352,6 @@ class _LayerNormMLP(torch.autograd.Function):
                 ctx.activation_dtype,
                 get_workspace(),
                 accumulate=accumulate_wgrad_into_param_main_grad,
-                fp32_output=ctx.fuse_wgrad_accumulation,
                 out=fc2_weight.main_grad
                 if ctx.fuse_wgrad_accumulation
                 else None,
@@ -2390,7 +2383,6 @@ class _LayerNormMLP(torch.autograd.Function):
                 grad=True,
                 use_bias=ctx.use_bias,
                 accumulate=accumulate_wgrad_into_param_main_grad,
-                fp32_output=ctx.fuse_wgrad_accumulation,
                 out=fc2_weight.main_grad
                 if ctx.fuse_wgrad_accumulation
                 else None,
@@ -2446,7 +2438,6 @@ class _LayerNormMLP(torch.autograd.Function):
                 grad=True,
                 use_bias=ctx.use_bias,
                 accumulate=accumulate_wgrad_into_param_main_grad,
-                fp32_output=ctx.fuse_wgrad_accumulation,
                 out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
             )
@@ -2491,7 +2482,6 @@ class _LayerNormMLP(torch.autograd.Function):
                 ctx.activation_dtype,
                 get_workspace(),
                 accumulate=accumulate_wgrad_into_param_main_grad,
-                fp32_output=ctx.fuse_wgrad_accumulation,
                 out=fc1_weight.main_grad
                 if ctx.fuse_wgrad_accumulation
                 else None,
@@ -2513,7 +2503,6 @@ class _LayerNormMLP(torch.autograd.Function):
                 layout="NT",
                 grad=True,
                 accumulate=accumulate_wgrad_into_param_main_grad,
-                fp32_output=ctx.fuse_wgrad_accumulation,
                 out=fc1_weight.main_grad
                 if ctx.fuse_wgrad_accumulation
                 else None,
@@ -2529,7 +2518,6 @@ class _LayerNormMLP(torch.autograd.Function):
                 grad=True,
                 use_bias=not ctx.bias_gelu_nvfusion,
                 accumulate=accumulate_wgrad_into_param_main_grad,
-                fp32_output=ctx.fuse_wgrad_accumulation,
                 out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
             )
...
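
Dropping `fp32_output=ctx.fuse_wgrad_accumulation` at these call sites works because the `main_grad` buffer passed as `out` already carries the FP32 dtype, which the GEMM wrappers now pick up automatically. A hedged illustration; the `main_grad` attribute is attached by hand here, the way Megatron-style training loops typically set it up:

```python
import torch

# Toy parameter with an FP32 main_grad accumulation buffer (shapes illustrative).
weight = torch.nn.Parameter(torch.randn(16, 32, dtype=torch.bfloat16))
weight.main_grad = torch.zeros(16, 32, dtype=torch.float32)

fuse_wgrad_accumulation = True
out = weight.main_grad if fuse_wgrad_accumulation else None
print(out.dtype)  # torch.float32 -- the wgrad GEMM accumulates in FP32 without a flag
```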