Unverified Commit 1f36c2c9 authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

[PyTorch] Fix linter warnings from unused args (#816)



* Fix linter warnings from unused args
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update .gitignore
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 071b9508
...@@ -1757,11 +1757,12 @@ class _PrepareQKVForFA(torch.autograd.Function): ...@@ -1757,11 +1757,12 @@ class _PrepareQKVForFA(torch.autograd.Function):
to separate contiguous q, k, v tensors in (b, s, ...) layout.""" to separate contiguous q, k, v tensors in (b, s, ...) layout."""
@staticmethod @staticmethod
def forward(ctx, def forward(
_ctx: torch.autograd.function.FunctionCtx, # unused
query_layer: torch.Tensor, query_layer: torch.Tensor,
key_layer: torch.Tensor, key_layer: torch.Tensor,
value_layer: torch.Tensor value_layer: torch.Tensor
) -> torch.Tensor: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# All inputs received are non-contiguous tensors. # All inputs received are non-contiguous tensors.
# The `query_layer` tensor is used to access the # The `query_layer` tensor is used to access the
# full memory region of the QKV tensor. # full memory region of the QKV tensor.
...@@ -1773,7 +1774,8 @@ class _PrepareQKVForFA(torch.autograd.Function): ...@@ -1773,7 +1774,8 @@ class _PrepareQKVForFA(torch.autograd.Function):
return query_layer, key_layer, value_layer return query_layer, key_layer, value_layer
@staticmethod @staticmethod
def backward(ctx, def backward(
_ctx: torch.autograd.function.FunctionCtx, # unused
dq: torch.Tensor, dq: torch.Tensor,
dk: torch.Tensor, dk: torch.Tensor,
dv: torch.Tensor dv: torch.Tensor
......
...@@ -46,7 +46,7 @@ class _FromFloat8Func(torch.autograd.Function): ...@@ -46,7 +46,7 @@ class _FromFloat8Func(torch.autograd.Function):
"""Cast from FP8 to other dtype""" """Cast from FP8 to other dtype"""
@staticmethod @staticmethod
def forward( def forward(
ctx, _ctx: torch.autograd.function.FunctionCtx, # unused
tensor: Float8Tensor, tensor: Float8Tensor,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -63,7 +63,10 @@ class _FromFloat8Func(torch.autograd.Function): ...@@ -63,7 +63,10 @@ class _FromFloat8Func(torch.autograd.Function):
return out return out
@staticmethod @staticmethod
def backward(ctx, grad): def backward(
_ctx: torch.autograd.function.FunctionCtx, # unused
grad: torch.Tensor,
) -> Tuple[Optional[torch.Tensor], ...]:
# Assume that we want gradients in full precision # Assume that we want gradients in full precision
return grad, None return grad, None
...@@ -97,7 +100,7 @@ class _ToFloat8Func(torch.autograd.Function): ...@@ -97,7 +100,7 @@ class _ToFloat8Func(torch.autograd.Function):
"""Cast to FP8 from other dtype""" """Cast to FP8 from other dtype"""
@staticmethod @staticmethod
def forward( def forward(
ctx, _ctx: torch.autograd.function.FunctionCtx, # unused
tensor: torch.Tensor, tensor: torch.Tensor,
fp8_meta: Optional[Dict[str, Any]] = None, fp8_meta: Optional[Dict[str, Any]] = None,
fp8_meta_forward: bool = True, fp8_meta_forward: bool = True,
...@@ -106,7 +109,7 @@ class _ToFloat8Func(torch.autograd.Function): ...@@ -106,7 +109,7 @@ class _ToFloat8Func(torch.autograd.Function):
scale: Optional[torch.Tensor] = None, scale: Optional[torch.Tensor] = None,
amax: Optional[torch.Tensor] = None, amax: Optional[torch.Tensor] = None,
scale_inv: Optional[torch.Tensor] = None, scale_inv: Optional[torch.Tensor] = None,
): ) -> Float8Tensor:
# Manually compute scale-inverse if needed # Manually compute scale-inverse if needed
if scale is not None and scale_inv is None: if scale is not None and scale_inv is None:
...@@ -189,7 +192,10 @@ class _ToFloat8Func(torch.autograd.Function): ...@@ -189,7 +192,10 @@ class _ToFloat8Func(torch.autograd.Function):
) )
@staticmethod @staticmethod
def backward(ctx, grad): def backward(
_ctx: torch.autograd.function.FunctionCtx, # unused
grad: torch.Tensor,
) -> Tuple[Optional[torch.Tensor], ...]:
# Assume that we want gradients in full precision # Assume that we want gradients in full precision
return grad, None, None, None, None, None, None, None return grad, None, None, None, None, None, None, None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment