Unverified commit d99142a0 authored by Kirthi Shankar Sivamani, committed by GitHub

Add auto-formatter (#919)



* Initial config test
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* remove linters, fix clang-format
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix clang-format
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix clang-format
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Remove lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Adjust config
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* use config file
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* adjust pylintrc
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* pre-format fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Python only
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add FA module
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Update CI configs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* CRLF -> LF
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* format
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* revert accidental formatting changes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* try with sudo
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* cpp formatting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix pylint error properly
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* some review comments
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* lint fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* add fp8 attn include in the correct file
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* autofix PRs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 43569381
---
Language: Cpp
# BasedOnStyle: Google
AccessModifierOffset: -1
AlignAfterOpenBracket: Align
AlignArrayOfStructures: None
AlignConsecutiveAssignments:
  Enabled: false
  AcrossEmptyLines: false
  AcrossComments: false
  AlignCompound: false
  AlignFunctionPointers: false
  PadOperators: true
AlignConsecutiveBitFields:
  Enabled: false
  AcrossEmptyLines: false
  AcrossComments: false
  AlignCompound: false
  AlignFunctionPointers: false
  PadOperators: false
AlignConsecutiveDeclarations:
  Enabled: false
  AcrossEmptyLines: false
  AcrossComments: false
  AlignCompound: false
  AlignFunctionPointers: false
  PadOperators: false
AlignConsecutiveMacros:
  Enabled: false
  AcrossEmptyLines: false
  AcrossComments: false
  AlignCompound: false
  AlignFunctionPointers: false
  PadOperators: false
AlignConsecutiveShortCaseStatements:
  Enabled: false
  AcrossEmptyLines: false
  AcrossComments: false
  AlignCaseColons: false
AlignEscapedNewlines: Left
AlignOperands: Align
AlignTrailingComments:
  Kind: Always
  OverEmptyLines: 0
AllowAllArgumentsOnNextLine: true
AllowAllParametersOfDeclarationOnNextLine: true
AllowBreakBeforeNoexceptSpecifier: Never
AllowShortBlocksOnASingleLine: Never
AllowShortCaseLabelsOnASingleLine: false
AllowShortCompoundRequirementOnASingleLine: true
AllowShortEnumsOnASingleLine: true
AllowShortFunctionsOnASingleLine: All
AllowShortIfStatementsOnASingleLine: WithoutElse
AllowShortLambdasOnASingleLine: All
AllowShortLoopsOnASingleLine: true
AlwaysBreakAfterDefinitionReturnType: None
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: true
AlwaysBreakTemplateDeclarations: Yes
AttributeMacros:
  - __capability
BinPackArguments: true
BinPackParameters: true
BitFieldColonSpacing: Both
BraceWrapping:
  AfterCaseLabel: false
  AfterClass: false
  AfterControlStatement: Never
  AfterEnum: false
  AfterExternBlock: false
  AfterFunction: false
  AfterNamespace: false
  AfterObjCDeclaration: false
  AfterStruct: false
  AfterUnion: false
  BeforeCatch: false
  BeforeElse: false
  BeforeLambdaBody: false
  BeforeWhile: false
  IndentBraces: false
  SplitEmptyFunction: true
  SplitEmptyRecord: true
  SplitEmptyNamespace: true
BreakAdjacentStringLiterals: true
BreakAfterAttributes: Leave
BreakAfterJavaFieldAnnotations: false
BreakArrays: true
BreakBeforeBinaryOperators: None
BreakBeforeConceptDeclarations: Always
BreakBeforeBraces: Attach
BreakBeforeInlineASMColon: OnlyMultiline
BreakBeforeTernaryOperators: true
BreakConstructorInitializers: BeforeColon
BreakInheritanceList: BeforeColon
BreakStringLiterals: true
ColumnLimit: 100
CommentPragmas: '^ IWYU pragma:'
CompactNamespaces: false
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: true
DerivePointerAlignment: true
DisableFormat: false
EmptyLineAfterAccessModifier: Never
EmptyLineBeforeAccessModifier: LogicalBlock
ExperimentalAutoDetectBinPacking: false
FixNamespaceComments: true
ForEachMacros:
  - foreach
  - Q_FOREACH
  - BOOST_FOREACH
IfMacros:
  - KJ_IF_MAYBE
IncludeBlocks: Regroup
IncludeCategories:
  - Regex: '^<ext/.*\.h>'
    Priority: 2
    SortPriority: 0
    CaseSensitive: false
  - Regex: '^<.*\.h>'
    Priority: 1
    SortPriority: 0
    CaseSensitive: false
  - Regex: '^<.*'
    Priority: 2
    SortPriority: 0
    CaseSensitive: false
  - Regex: '.*'
    Priority: 3
    SortPriority: 0
    CaseSensitive: false
IncludeIsMainRegex: '([-_](test|unittest))?$'
IncludeIsMainSourceRegex: ''
IndentAccessModifiers: false
IndentCaseBlocks: false
IndentCaseLabels: true
IndentExternBlock: AfterExternBlock
IndentGotoLabels: true
IndentPPDirectives: None
IndentRequiresClause: true
IndentWidth: 2
IndentWrappedFunctionNames: false
InsertBraces: false
InsertNewlineAtEOF: false
InsertTrailingCommas: None
IntegerLiteralSeparator:
  Binary: 0
  BinaryMinDigits: 0
  Decimal: 0
  DecimalMinDigits: 0
  Hex: 0
  HexMinDigits: 0
JavaScriptQuotes: Leave
JavaScriptWrapImports: true
KeepEmptyLinesAtTheStartOfBlocks: false
KeepEmptyLinesAtEOF: false
LambdaBodyIndentation: Signature
LineEnding: DeriveLF
MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBinPackProtocolList: Never
ObjCBlockIndentWidth: 2
ObjCBreakBeforeNestedBlockParam: true
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: true
PackConstructorInitializers: NextLine
PenaltyBreakAssignment: 2
PenaltyBreakBeforeFirstCallParameter: 1
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakOpenParenthesis: 0
PenaltyBreakScopeResolution: 500
PenaltyBreakString: 1000
PenaltyBreakTemplateDeclaration: 10
PenaltyExcessCharacter: 1000000
PenaltyIndentedWhitespace: 0
PenaltyReturnTypeOnItsOwnLine: 200
PointerAlignment: Left
PPIndentWidth: -1
QualifierAlignment: Leave
RawStringFormats:
  - Language: Cpp
    Delimiters:
      - cc
      - CC
      - cpp
      - Cpp
      - CPP
      - 'c++'
      - 'C++'
    CanonicalDelimiter: ''
    BasedOnStyle: google
  - Language: TextProto
    Delimiters:
      - pb
      - PB
      - proto
      - PROTO
    EnclosingFunctions:
      - EqualsProto
      - EquivToProto
      - PARSE_PARTIAL_TEXT_PROTO
      - PARSE_TEST_PROTO
      - PARSE_TEXT_PROTO
      - ParseTextOrDie
      - ParseTextProtoOrDie
      - ParseTestProto
      - ParsePartialTestProto
    CanonicalDelimiter: pb
    BasedOnStyle: google
ReferenceAlignment: Pointer
ReflowComments: false
RemoveBracesLLVM: false
RemoveParentheses: Leave
RemoveSemicolon: false
RequiresClausePosition: OwnLine
RequiresExpressionIndentation: OuterScope
SeparateDefinitionBlocks: Leave
ShortNamespaceLines: 1
SkipMacroDefinitionBody: false
SortIncludes: CaseSensitive
SortJavaStaticImport: Before
SortUsingDeclarations: LexicographicNumeric
SpaceAfterCStyleCast: false
SpaceAfterLogicalNot: false
SpaceAfterTemplateKeyword: true
SpaceAroundPointerQualifiers: Default
SpaceBeforeAssignmentOperators: true
SpaceBeforeCaseColon: false
SpaceBeforeCpp11BracedList: false
SpaceBeforeCtorInitializerColon: true
SpaceBeforeInheritanceColon: true
SpaceBeforeJsonColon: false
SpaceBeforeParens: ControlStatements
SpaceBeforeParensOptions:
  AfterControlStatements: true
  AfterForeachMacros: true
  AfterFunctionDefinitionName: false
  AfterFunctionDeclarationName: false
  AfterIfMacros: true
  AfterOverloadedOperator: false
  AfterPlacementOperator: true
  AfterRequiresInClause: false
  AfterRequiresInExpression: false
  BeforeNonEmptyParentheses: false
SpaceBeforeRangeBasedForLoopColon: true
SpaceBeforeSquareBrackets: false
SpaceInEmptyBlock: false
SpacesBeforeTrailingComments: 2
SpacesInAngles: Never
SpacesInContainerLiterals: true
SpacesInLineCommentPrefix:
  Minimum: 1
  Maximum: -1
SpacesInParens: Never
SpacesInParensOptions:
  InCStyleCasts: false
  InConditionalStatements: false
  InEmptyParentheses: false
  Other: false
SpacesInSquareBrackets: false
Standard: Auto
StatementAttributeLikeMacros:
  - Q_EMIT
StatementMacros:
  - Q_UNUSED
  - QT_REQUIRE_VERSION
TabWidth: 8
UseTab: Never
VerilogBreakBetweenInstancePorts: true
WhitespaceSensitiveMacros:
  - BOOST_PP_STRINGIZE
  - CF_SWIFT_NAME
  - NS_SWIFT_NAME
  - PP_STRINGIZE
  - STRINGIZE
...
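For a quick local pass against this style file, something like the following works; a sketch with an illustrative file path, where clang-format 18 matches the pre-commit mirror pinned further below:

# Rewrite a file in place using the nearest .clang-format (-style=file).
clang-format -i -style=file transformer_engine/common/transformer_engine.cpp
# Check-only mode: exit non-zero if the file would be reformatted.
clang-format --dry-run -Werror -style=file transformer_engine/common/transformer_engine.cpp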
@@ -11,9 +11,8 @@ jobs:
  pytorch:
    name: 'PyTorch'
    runs-on: ubuntu-latest
    if: false # NGC PyTorch container does not fit on GitHub runner
    container:
      image: nvcr.io/nvidia/pytorch:23.03-py3
      image: nvcr.io/nvidia/pytorch:24.05-py3
      options: --user root
    steps:
      - name: 'Checkout'
@@ -44,3 +43,20 @@ jobs:
          NVTE_FRAMEWORK: jax
      - name: 'Sanity check'
        run: python tests/jax/test_sanity_import.py
  paddle:
    name: 'PaddlePaddle'
    runs-on: ubuntu-latest
    container:
      image: nvcr.io/nvidia/paddlepaddle:24.05-py3
      options: --user root
    steps:
      - name: 'Checkout'
        uses: actions/checkout@v3
        with:
          submodules: recursive
      - name: 'Build'
        run: pip install . -v
        env:
          NVTE_FRAMEWORK: paddle
      - name: 'Sanity check'
        run: python tests/paddle/test_sanity_import.py
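The new PaddlePaddle job can be reproduced outside CI; a minimal sketch, assuming a working PaddlePaddle environment, mirroring the steps above:

# NVTE_FRAMEWORK selects which framework bindings to build.
NVTE_FRAMEWORK=paddle pip install . -v
# The sanity check only needs the import to succeed.
python tests/paddle/test_sanity_import.py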
@@ -16,22 +16,22 @@ jobs:
        uses: actions/checkout@v3
      - name: 'Lint'
        run: |
          sudo apt-get update
          sudo apt-get install pip -y
          export CPP_ONLY=1
          export TE_PATH=.
          bash ./qa/L0_pytorch_lint/test.sh
  pytorch_pylint:
    name: 'PyTorch Python'
    runs-on: ubuntu-latest
    if: false # NGC PyTorch container does not fit on GitHub runner
    container:
      image: nvcr.io/nvidia/pytorch:23.03-py3
      options: --user root
    steps:
      - name: 'Checkout'
        uses: actions/checkout@v3
      - name: 'Lint'
        run: |
          pip install flash-attn==1.0.2
          sudo apt-get update
          sudo apt-get install pip -y
          pip install torch
          export PYTHON_ONLY=1
          export TE_PATH=.
          bash ./qa/L0_pytorch_lint/test.sh
@@ -43,20 +43,48 @@ jobs:
        uses: actions/checkout@v3
      - name: 'Lint'
        run: |
          sudo apt-get update
          sudo apt-get install pip -y
          export CPP_ONLY=1
          export TE_PATH=.
          bash ./qa/L0_jax_lint/test.sh
  jax_pylint:
    name: 'JAX Python'
    runs-on: ubuntu-latest
    container:
      image: ghcr.io/nvidia/jax:latest
      options: --user root
    steps:
      - name: 'Checkout'
        uses: actions/checkout@v3
      - name: 'Lint'
        run: |
          sudo apt-get update
          sudo apt-get install pip -y
          export PYTHON_ONLY=1
          export TE_PATH=.
          bash ./qa/L0_jax_lint/test.sh
  paddle_cpplint:
    name: 'PaddlePaddle C++'
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v3
      - name: 'Lint'
        run: |
          sudo apt-get update
          sudo apt-get install pip -y
          export CPP_ONLY=1
          export TE_PATH=.
          bash ./qa/L0_paddle_lint/test.sh
  paddle_pylint:
    name: 'PaddlePaddle Python'
    runs-on: ubuntu-latest
    steps:
      - name: 'Checkout'
        uses: actions/checkout@v3
      - name: 'Lint'
        run: |
          sudo apt-get update
          sudo apt-get install pip -y
          pip install paddlepaddle-gpu
          export PYTHON_ONLY=1
          export TE_PATH=.
          bash ./qa/L0_paddle_lint/test.sh
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
default_language_version:
  python: python3
ci:
  autofix_prs: true
  autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions'
  autoupdate_schedule: quarterly
  submodules: false
  skip: []
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
  rev: v4.6.0
  hooks:
  - id: check-merge-conflict
  - id: check-added-large-files
  - id: end-of-file-fixer
    files: .*\.(c|cc|cxx|cpp|cu|cuh|h|hpp|py)$
  - id: trailing-whitespace
    files: .*\.(c|cc|cxx|cpp|cu|cuh|h|hpp|py)$
- repo: https://github.com/psf/black
  rev: 24.4.2
  hooks:
  - id: black
    name: Format python code
    args: [--line-length=100, --preview, --enable-unstable-feature=string_processing]
    types: [python]
- repo: https://github.com/pre-commit/mirrors-clang-format
  rev: v18.1.6
  hooks:
  - id: clang-format
    entry: clang-format -i
    args: ["-style=file"]
    files: ^transformer_engine.*\.(c|cc|cxx|cpp|cu|cuh|h|hpp)$
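With this file at the repo root, the hooks run locally through the standard pre-commit CLI; a minimal sketch:

pip install pre-commit
# Register the git hook so black and clang-format run on each commit.
pre-commit install
# One-off pass over the whole tree, as pre-commit.ci does for autofix PRs.
pre-commit run --all-files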
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# Stop searching for additional config files.
set noparent
# Limit line length.
linelength=100
# Ignore the following errors.
filter=-build/include_subdir
filter=-build/namespaces
filter=-readability/todo
filter=-build/header_guard
filter=-build/include
filter=-build/c++11
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# Stop searching for additional config files.
set noparent
# Limit line length.
linelength=100
# Ignore the following errors.
filter=-build/include_subdir
filter=-build/namespaces
filter=-readability/todo
filter=-build/header_guard
filter=-build/include
filter=-build/c++11
filter=-runtime/references
filter=-whitespace
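cpplint picks a cfg like this up automatically: it looks for CPPLINT.cfg in each linted file's directory and walks up parent directories, and `set noparent` stops that search. This variant also filters runtime/references, which is why the per-line NOLINTNEXTLINE(runtime/references) comments are removed in the C++ diffs further below. A sketch mirroring the QA scripts, with the pinned version:

pip install cpplint==1.6.0
# Run from a directory containing CPPLINT.cfg so the filters apply.
cpplint --recursive transformer_engine/common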
[MASTER]
extension-pkg-whitelist=torch,
                        transformer_engine_torch
extension-pkg-whitelist=flash_attn_2_cuda,
                        torch,
                        transformer_engine_torch,
                        transformer_engine_paddle,
                        transformer_engine_jax
extension-pkg-allow-list=transformer_engine.transformer_engine_jax
disable=too-many-locals,
        too-many-public-methods,
@@ -24,7 +29,9 @@ disable=too-many-locals,
        global-statement,
        too-many-branches,
        global-variable-not-assigned,
        redefined-argument-from-local
        redefined-argument-from-local,
        line-too-long,
        too-many-return-statements
[TYPECHECK]
ignored-modules=torch
......
[MASTER]
extension-pkg-whitelist=transformer_engine_jax
extension-pkg-allow-list=transformer_engine.transformer_engine_jax
disable=too-many-locals,
        invalid-name,
        too-many-arguments,
        abstract-method,
        arguments-differ,
        too-many-instance-attributes,
        unsubscriptable-object,
        import-outside-toplevel,
        too-many-statements,
        import-error,
        too-many-lines,
        use-maxsplit-arg,
        protected-access,
        pointless-string-statement,
        cyclic-import,
        duplicate-code,
        no-member,
        attribute-defined-outside-init,
        global-statement,
        too-many-branches,
        global-variable-not-assigned,
        redefined-argument-from-local
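Running pylint against this rc outside the QA script is straightforward; a sketch using the pinned version, where the QA script instead copies the file to the repo root as `pylintrc`:

pip install pylint==2.13.5
# --rcfile points pylint at the framework-specific config directly.
pylint --rcfile=qa/L0_jax_lint/pylintrc --recursive=y transformer_engine/common transformer_engine/jax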
@@ -9,17 +9,15 @@ set -e
pip install cpplint==1.6.0 pylint==2.13.5
if [ -z "${PYTHON_ONLY}" ]
then
cp $TE_PATH/qa/L0_jax_lint/CPPLINT.cfg $TE_PATH
cd $TE_PATH
echo "Checking common API headers"
cpplint --root transformer_engine/common/include --recursive transformer_engine/common/include
echo "Checking C++ files"
cpplint --recursive --exclude=transformer_engine/common/include transformer_engine/common
cpplint --recursive --exclude=transformer_engine/common/include --exclude=transformer_engine/build_tools/build transformer_engine/common
cpplint --recursive transformer_engine/jax
fi
if [ -z "${CPP_ONLY}" ]
then
cp $TE_PATH/qa/L0_jax_lint/pylintrc $TE_PATH
cd $TE_PATH
echo "Checking Python files"
pylint --recursive=y transformer_engine/common transformer_engine/jax
......
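Because the script branches on CPP_ONLY and PYTHON_ONLY, either half runs in isolation, exactly as the CI jobs invoke it:

# C++ lint only (matches the jax_cpplint job).
CPP_ONLY=1 TE_PATH=. bash ./qa/L0_jax_lint/test.sh
# Python lint only (matches the jax_pylint job).
PYTHON_ONLY=1 TE_PATH=. bash ./qa/L0_jax_lint/test.sh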
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# Stop searching for additional config files.
set noparent
# Limit line length.
linelength=100
# Ignore the following errors.
filter=-build/include_subdir
filter=-build/namespaces
filter=-readability/todo
filter=-build/header_guard
filter=-build/include
filter=-build/c++11
[MASTER]
extension-pkg-whitelist=transformer_engine_paddle
disable=too-many-locals,
        invalid-name,
        too-many-arguments,
        abstract-method,
        arguments-differ,
        too-many-instance-attributes,
        unsubscriptable-object,
        import-outside-toplevel,
        too-many-statements,
        import-error,
        too-many-lines,
        use-maxsplit-arg,
        protected-access,
        pointless-string-statement,
        cyclic-import,
        duplicate-code,
        no-member,
        attribute-defined-outside-init,
        global-statement,
        too-many-branches,
        global-variable-not-assigned,
        redefined-argument-from-local
@@ -9,18 +9,16 @@ set -e
pip install cpplint==1.6.0 pylint==2.13.5
if [ -z "${PYTHON_ONLY}" ]
then
cp $TE_PATH/qa/L0_paddle_lint/CPPLINT.cfg $TE_PATH
cd $TE_PATH
echo "Checking common API headers"
cpplint --root transformer_engine/common/include --recursive transformer_engine/common/include
echo "Checking C++ files"
cpplint --recursive --exclude=transformer_engine/common/include transformer_engine/common
cpplint --recursive --exclude=transformer_engine/common/include --exclude=transformer_engine/build_tools/build transformer_engine/common
cpplint --recursive transformer_engine/paddle
fi
if [ -z "${CPP_ONLY}" ]
then
cp $TE_PATH/qa/L0_paddle_lint/pylintrc $TE_PATH
cd $TE_PATH
echo "Checking Python files"
python -m pylint --recursive=y transformer_engine/common transformer_engine/paddle
pylint --recursive=y transformer_engine/common transformer_engine/paddle
fi
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# Stop searching for additional config files.
set noparent
# Limit line length.
linelength=100
# Ignore the following errors.
filter=-build/include_subdir
filter=-build/namespaces
filter=-readability/todo
filter=-build/header_guard
filter=-build/include
filter=-build/c++11
@@ -9,17 +9,15 @@ set -e
pip install cpplint==1.6.0 pylint==2.13.5
if [ -z "${PYTHON_ONLY}" ]
then
cp $TE_PATH/qa/L0_pytorch_lint/CPPLINT.cfg $TE_PATH
cd $TE_PATH
echo "Checking common API headers"
cpplint --root transformer_engine/common/include --recursive transformer_engine/common/include
echo "Checking C++ files"
cpplint --recursive --exclude=transformer_engine/common/include --exclude=transformer_engine/build_tools/build transformer_engine
cpplint --recursive --exclude=transformer_engine/common/include --exclude=transformer_engine/build_tools/build transformer_engine/common
cpplint --recursive transformer_engine/pytorch
fi
if [ -z "${CPP_ONLY}" ]
then
cp $TE_PATH/qa/L0_pytorch_lint/pylintrc $TE_PATH
cd $TE_PATH
echo "Checking Python files"
pylint --recursive=y transformer_engine/common transformer_engine/pytorch
......
@@ -6,4 +6,3 @@ set -xe
: ${TE_PATH:=/opt/transformerengine}
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_*
#!/bin/bash
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
python_files=`find transformer_engine tests setup.py examples -name '*.py'`
for f in $python_files
do
    black $f
done
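A roughly equivalent single invocation, assuming the same settings the black pre-commit hook passes explicitly, avoids spawning one process per file:

# Flags mirror the black hook in .pre-commit-config.yaml; black recurses into directories.
black --line-length=100 --preview --enable-unstable-feature=string_processing \
    transformer_engine tests setup.py examples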
import transformer_engine.paddle
print("OK")
@@ -43,7 +43,6 @@ namespace fused_attn {
static void createScale(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
NVTE_QKV_Layout layout, cudnnDataType_t tensorType,
// NOLINTNEXTLINE(runtime/references)
std::vector<cudnn_frontend::Operation> &ops) {
// scale
int64_t scale_dim[4] = {1, 1, 1, 1};
@@ -72,7 +71,6 @@ static void createScale(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t
static cudnn_frontend::Tensor createBMM1(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
NVTE_QKV_Layout layout, cudnnDataType_t tensorType,
bool zero_s,
// NOLINTNEXTLINE(runtime/references)
std::vector<cudnn_frontend::Operation> &ops) {
// Creates the necessary tensor descriptors
int64_t q_dim[4] = {b, h, s_q, d};
@@ -132,7 +130,6 @@ static cudnn_frontend::Tensor createBMM1(int64_t b, int64_t h, int64_t s_q, int6
static cudnn_frontend::Tensor createBias(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
NVTE_QKV_Layout layout, cudnnDataType_t tensorType,
// NOLINTNEXTLINE(runtime/references)
std::vector<cudnn_frontend::Operation> &ops,
cudnn_frontend::Tensor const &prevBlockOutputTensor) {
NVTE_CHECK(ops.size() != 0, "Bias op constructed incorrectly as the first one.");
@@ -165,7 +162,6 @@ static cudnn_frontend::Tensor createBias(int64_t b, int64_t h, int64_t s_q, int6
static cudnn_frontend::Tensor createMask(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
NVTE_QKV_Layout layout, NVTE_Mask_Type mask_type,
cudnnDataType_t tensorType,
// NOLINTNEXTLINE(runtime/references)
std::vector<cudnn_frontend::Operation> &ops,
cudnn_frontend::Tensor const &prevBlockOutputTensor,
bool is_bprop) {
@@ -328,7 +324,6 @@ static cudnn_frontend::Tensor createMask(int64_t b, int64_t h, int64_t s_q, int6
static cudnn_frontend::Tensor createSoftmaxForward(
int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, NVTE_QKV_Layout layout,
bool enable_dropout, bool softmax_output_virtual, cudnnDataType_t tensorType,
// NOLINTNEXTLINE(runtime/references)
std::vector<cudnn_frontend::Operation> &ops,
cudnn_frontend::Tensor const &prevBlockOutputTensor) {
int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv};
@@ -432,7 +427,6 @@ static cudnn_frontend::Tensor createSoftmaxForward(
static cudnn_frontend::Tensor createDropout(int64_t b, int64_t h, int64_t s_q, int64_t s_kv,
int64_t d, double probability,
cudnnDataType_t tensorType,
// NOLINTNEXTLINE(runtime/references)
std::vector<cudnn_frontend::Operation> &ops,
cudnn_frontend::Tensor const &prevBlockOutputTensor) {
NVTE_CHECK(ops.size() != 0, "Dropout DAG constructed incorrectly as the first one");
@@ -512,7 +506,6 @@ static cudnn_frontend::Tensor createDropout(int64_t b, int64_t h, int64_t s_q, i
static void createBMM2(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
NVTE_QKV_Layout layout, cudnnDataType_t tensorType,
// NOLINTNEXTLINE(runtime/references)
std::vector<cudnn_frontend::Operation> &ops,
cudnn_frontend::Tensor const &prevBlockOutputTensor) {
NVTE_CHECK(ops.size() != 0, "BMM2 op constructed incorrectly as the first one");
@@ -559,7 +552,6 @@ static void createBMM2(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t
static cudnn_frontend::Tensor createSoftmaxBackward(int64_t b, int64_t h, int64_t s_q, int64_t s_kv,
int64_t d, NVTE_QKV_Layout layout,
cudnnDataType_t tensorType,
// NOLINTNEXTLINE(runtime/references)
std::vector<cudnn_frontend::Operation> &ops,
cudnn_frontend::Tensor const &yTensor,
cudnn_frontend::Tensor const &dyTensor) {
......
@@ -4,8 +4,6 @@
* See LICENSE for license information.
************************************************************************/
#include "transformer_engine/fused_attn.h"
#include "../common.h"
#include "utils.h"
#include "../util/system.h"
......
@@ -9,6 +9,7 @@
*/
#include "transformer_engine/transformer_engine.h"
#include "transformer_engine/fused_attn.h"
namespace transformer_engine {
#if (CUDNN_VERSION >= 8900)
......
@@ -73,12 +73,12 @@ def mask_to_cu_seqlens(mask: paddle.Tensor, need_kv: bool = False) -> paddle.Ten
"""Convert mask to cu_seqlens"""
assert 'bool' in str(mask.dtype), "mask must be bool dtype"
assert len(mask.shape) == 4 and mask.shape[1] == 1, "mask must be [b, 1, s_q, s_kv]"
q_actual_seqlens = paddle.sum(mask[:, :, :, 0] == False, axis=(-1, -2), dtype='int32') # pylint: disable=singleton-comparison
q_actual_seqlens = paddle.sum(mask[:, :, :, 0].logical_not(), axis=(-1, -2), dtype='int32')
q_cu_seqlens = paddle.cumsum(q_actual_seqlens)
q_cu_seqlens = paddle.concat([paddle.zeros([1], dtype=paddle.int32), q_cu_seqlens], axis=0)
if not need_kv:
return q_cu_seqlens, None
kv_actual_seqlens = paddle.sum(mask[:, :, 0, :] == False, axis=(-1, -2), dtype='int32') # pylint: disable=singleton-comparison
kv_actual_seqlens = paddle.sum(mask[:, :, 0, :].logical_not(), axis=(-1, -2), dtype='int32')
kv_cu_seqlens = paddle.cumsum(kv_actual_seqlens)
kv_cu_seqlens = paddle.concat([paddle.zeros([1], dtype=paddle.int32), kv_cu_seqlens], axis=0)
return q_cu_seqlens, kv_cu_seqlens
......
@@ -79,10 +79,10 @@ _flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4")
_flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1")
if _flash_attn_version >= _flash_attn_version_required:
from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_forward_func # pylint: disable=no-name-in-module
from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd # pylint: disable=no-name-in-module
from flash_attn.flash_attn_interface import _flash_attn_varlen_forward as _flash_attn_forward # pylint: disable=no-name-in-module,ungrouped-imports
from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward # pylint: disable=no-name-in-module
from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_forward_func
from flash_attn.flash_attn_interface import _flash_attn_varlen_forward as _flash_attn_forward
from flash_attn.flash_attn_interface import _flash_attn_varlen_backward as _flash_attn_backward
from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd
META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT
META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1
@@ -115,6 +115,7 @@ _alibi_cache = {
__all__ = ["DotProductAttention", "InferenceParams", "MultiheadAttention"]
class InferenceParams: # pylint: disable=too-few-public-methods
"""
Inference parameters that are passed to the main model in order
@@ -1404,8 +1405,8 @@ def attn_forward_func_with_cp(
assert (qkv_format != 'thd' or \
not use_fused_attention or \
attn_mask_type in ["padding", "padding_causal"]
), f"""Context parallelism is not supported for {attn_mask_type} mask type and """ \
f"""{qkv_format} format with {"FusedAttention" if use_fused_attention else "FlashAttention"}!"""
), f"Context parallelism is not supported for {attn_mask_type} mask type and " \
f"{qkv_format} format with {'FusedAttention' if use_fused_attention else 'FlashAttention'}!"
assert (attn_bias is None or (use_fused_attention and "padding" not in attn_mask_type)
), """Attention bias is only supported with FusedAttention and "causal" """ \
"""or "no_mask" mask types!"""
......