OpenDAS / Megatron-LM · Commits · a0e873f8

Commit a0e873f8
Authored Jan 16, 2025 by silencealiang
Add compile; reduce graph compilation and CPU overhead
Parent: 06b52e5b
Pipeline #2222 passed
Changes: 3 · Pipelines: 1
Showing 3 changed files with 18 additions and 4 deletions (+18 -4)
megatron/core/jit.py                         +14  -0
megatron/core/tensor_parallel/layers.py       +1  -1
megatron/core/tensor_parallel/mappings.py     +3  -3
megatron/core/jit.py (view file @ a0e873f8)
@@ -8,3 +8,17 @@ jit_fuser = torch.jit.script

# nvFuser is deprecated in PyTorch JIT starting from 2.2
if is_torch_min_version("2.2.0a0"):
    jit_fuser = torch.compile


# Decorator to disable Torch Dynamo
# See: https://github.com/NVIDIA/TransformerEngine/issues/308
no_torch_dynamo = lambda recursive=True: lambda func: func
if torch.__version__ >= "2":
    import torch._dynamo

    if torch.__version__ >= "2.1":
        no_torch_dynamo = lambda recursive=True: lambda f: torch._dynamo.disable(f, recursive=recursive)
    else:
        # no "recursive" option in pyTorch 2.0 - it acts as if recursive was True
        no_torch_dynamo = lambda recursive=True: torch._dynamo.disable
\ No newline at end of file
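For reference, a minimal usage sketch of the decorator defined above (not part of the commit; log_shape and model_step are hypothetical names): a function wrapped with no_torch_dynamo() is excluded from TorchDynamo tracing, so a surrounding torch.compile region breaks the graph at the call and runs the function eagerly instead of tracing it.

import torch

from megatron.core.jit import no_torch_dynamo


@no_torch_dynamo()
def log_shape(x):
    # Python-side work that Dynamo would otherwise have to trace or guard on.
    print("shape:", tuple(x.shape))
    return x


@torch.compile
def model_step(x):
    x = x * 2.0
    x = log_shape(x)  # graph break: executed eagerly, outside the compiled graph
    return x + 1.0


out = model_step(torch.randn(4, 8))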
megatron/core/tensor_parallel/layers.py (view file @ a0e873f8)
@@ -237,7 +237,7 @@ class VocabParallelEmbedding(torch.nn.Module):

            )
            if config.perform_initialization:
                _initialize_affine_weight_gpu(
                    self.weight, init_method, partition_dim=0, stride=1
                )

    @torch.compile(mode='max-autotune-no-cudagraphs')
    def forward(self, input_):
        """Forward.
...
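The decorator shown above compiles VocabParallelEmbedding.forward with Inductor's max-autotune kernel search while leaving CUDA graphs off. A standalone sketch of the same pattern, using a toy module with hypothetical names rather than Megatron code:

import torch


class ToyEmbedding(torch.nn.Module):
    def __init__(self, vocab_size, dim):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(vocab_size, dim))

    # 'max-autotune-no-cudagraphs' enables Inductor's max-autotune kernel
    # search without CUDA graph capture.
    @torch.compile(mode='max-autotune-no-cudagraphs')
    def forward(self, input_):
        return torch.nn.functional.embedding(input_, self.weight)


emb = ToyEmbedding(1000, 64)
out = emb(torch.randint(0, 1000, (2, 16)))  # first call compiles; later calls reuse the compiled code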
megatron/core/tensor_parallel/mappings.py (view file @ a0e873f8)
@@ -462,13 +462,13 @@ class _AllToAll(torch.autograd.Function):

# -----------------
# Helper functions.
# -----------------

from megatron.core.jit import no_torch_dynamo


def copy_to_tensor_model_parallel_region(input_):
    """Wrapper for autograd function: forward: copy, backward allreduce"""
    return _CopyToModelParallelRegion.apply(input_)


@no_torch_dynamo()
def reduce_from_tensor_model_parallel_region(input_):
    """Wrapper for autograd function: forward: all reduce, backward copy"""
    return _ReduceFromModelParallelRegion.apply(input_)

...

@@ -501,7 +501,7 @@ def gather_from_sequence_parallel_region(

        input_, tensor_parallel_output_grad, group, output_split_sizes, use_global_buffer
    )


@no_torch_dynamo()
def reduce_scatter_to_sequence_parallel_region(
    input_, group=None, input_split_sizes=None, use_global_buffer=False
):
...
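To illustrate why the communication wrappers are decorated with @no_torch_dynamo(), here is a self-contained toy sketch (hypothetical names, plain PyTorch, no distributed setup): the wrapper stays outside Dynamo, so a compiled caller breaks the graph at the call and executes the wrapper eagerly while the surrounding tensor math is still compiled.

import torch
import torch._dynamo


@torch._dynamo.disable  # essentially what @no_torch_dynamo() expands to on PyTorch >= 2.1
def toy_collective_wrapper(x):
    # Stand-in for e.g. reduce_from_tensor_model_parallel_region; a real
    # wrapper calls an autograd.Function that performs the collective.
    return x


@torch.compile
def toy_layer(x, w):
    y = x @ w                      # traced and compiled
    y = toy_collective_wrapper(y)  # graph break: executed eagerly
    return torch.nn.functional.gelu(y)  # compiled in the resumed graph


out = toy_layer(torch.randn(4, 8), torch.randn(8, 8))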