Unverified Commit d99142a0 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Add auto-formatter (#919)



* Initial config test
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* remove linters, fix clang-format
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix clang-format
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix clang-format
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Remove lint
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Adjust config
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* use config file
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* adjust pylintrc
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* pre-format fixes
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Python only
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add FA module
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fixes
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Update CI configs
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* CRLF -> LF
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* format
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* revert accidental formatting changes
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* try with sudo
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* cpp formatting
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix pylint error properly
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* some review comments
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* lint fixes
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* add fp8 attn include in the correct file
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* autofix PRs
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 43569381
......@@ -175,21 +175,24 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
// wd_after_momentum);
// }
// Case 2. fp32, fp32, fp32, No
else if (grad_type == at::ScalarType::Float && weight_type == at::ScalarType::Float && // NOLINT(*)
else if (grad_type == at::ScalarType::Float && // NOLINT(*)
weight_type == at::ScalarType::Float &&
num_tensors == 3) {
multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
SGDFunctor<3, float, float>(), wd, momentum, dampening, lr, nesterov,
first_run, wd_after_momentum, scale);
}
// Case 3. fp16, fp32, fp32, Yes
else if (grad_type == at::ScalarType::Half && weight_type == at::ScalarType::Float && // NOLINT(*)
else if (grad_type == at::ScalarType::Half && // NOLINT(*)
weight_type == at::ScalarType::Float &&
num_tensors == 4) {
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
SGDFunctor<4, at::Half, float>(), wd, momentum, dampening, lr, nesterov,
first_run, wd_after_momentum, scale);
}
// Case 4. fp32, fp32, fp32, Yes
else if (grad_type == at::ScalarType::Float && weight_type == at::ScalarType::Float && // NOLINT(*)
else if (grad_type == at::ScalarType::Float && // NOLINT(*)
weight_type == at::ScalarType::Float &&
num_tensors == 4) {
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
SGDFunctor<4, float, float>(), wd, momentum, dampening, lr, nesterov,
......
......@@ -615,7 +615,9 @@ def checkpoint(
"`get_rng_state_tracker` must be passed as keyword arguments to `checkpoint`.",
DeprecationWarning, stacklevel=2,
)
distribute_saved_activations, get_rng_state_tracker, tp_group = args[:3] # pylint: disable=unbalanced-tuple-unpacking
distribute_saved_activations = args[0]
get_rng_state_tracker = args[1]
tp_group = args[2]
args = args[3:]
# Trigger the native PyTorch checkpoint if the function is not or does not contain a
......
......@@ -336,7 +336,7 @@ class FusedScaleMaskSoftmax(nn.Module):
return self.forward_fused_softmax(inp, mask, scale)
return self.forward_torch_softmax(inp, mask, scale)
def is_kernel_available(self, mask: torch.Tensor, b: int, np: int, sq: int, sk: int) -> bool: # pylint: disable=too-many-return-statements
def is_kernel_available(self, mask: torch.Tensor, b: int, np: int, sq: int, sk: int) -> bool:
"""Check FusedScaleMaskSoftmax kernel availability based on size"""
attn_batches = b * np
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment