zhougaofeng / internlm2-math-7B
"src/lib/vscode:/vscode.git/clone" did not exist on "8f35f85917758105705c3721030fc93c31b8677a"
Commit fe9a149a, authored Jun 11, 2024 by zhougaofeng
Upload New File
parent 7fae61ae
Pipeline #1167 canceled with stages
Showing 1 changed file with 55 additions and 0 deletions.
src/llmfactory/model/utils/attention.py (new file, 0 → 100644, +55 −0)
from typing import TYPE_CHECKING

from ...extras.logging import get_logger
from ...extras.packages import is_flash_attn2_available, is_sdpa_available


if TYPE_CHECKING:
    from transformers import PretrainedConfig

    from ...hparams import ModelArguments


logger = get_logger(__name__)


def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
    if model_args.flash_attn == "auto":
        return

    elif model_args.flash_attn == "off":
        requested_attn_implementation = "eager"

    elif model_args.flash_attn == "sdpa":
        if not is_sdpa_available():
            logger.warning("torch>=2.1.1 is required for SDPA attention.")
            return

        requested_attn_implementation = "sdpa"

    elif model_args.flash_attn == "fa2":
        if not is_flash_attn2_available():
            logger.warning("FlashAttention-2 is not installed.")
            return

        requested_attn_implementation = "flash_attention_2"

    else:
        raise NotImplementedError("Unknown attention type: {}".format(model_args.flash_attn))

    if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
        setattr(config, "attn_implementation", requested_attn_implementation)
    else:
        setattr(config, "_attn_implementation", requested_attn_implementation)


def print_attn_implementation(config: "PretrainedConfig") -> None:
    if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
        attn_implementation = getattr(config, "attn_implementation", None)
    else:
        attn_implementation = getattr(config, "_attn_implementation", None)

    if attn_implementation == "flash_attention_2":
        logger.info("Using FlashAttention-2 for faster training and inference.")
    elif attn_implementation == "sdpa":
        logger.info("Using torch SDPA for faster training and inference.")
    else:
        logger.info("Using vanilla attention implementation.")
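
For reference, a minimal sketch of how these helpers could be called during model setup. The AutoConfig call is standard transformers usage; the model id and the ModelArguments(flash_attn="sdpa") construction are assumptions for illustration only (the real ModelArguments dataclass in ...hparams may require additional fields).

# Usage sketch (illustrative, not part of the commit): assumes ModelArguments
# can be constructed with just a flash_attn field, which may not hold in practice.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("internlm/internlm2-math-7b", trust_remote_code=True)
model_args = ModelArguments(flash_attn="sdpa")  # hypothetical construction

configure_attn_implementation(config, model_args)  # writes (_)attn_implementation onto the config
print_attn_implementation(config)                  # logs which attention backend was selected

The config is patched in place, so it can then be passed straight to AutoModelForCausalLM.from_pretrained to load the model with the requested attention backend.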