OpenDAS / Megatron-LM

Commit bdd97312
Authored Oct 14, 2022 by Jared Casper
Parent: 5da3bb92

Remove noop used to try to force scheduling and check for environment variable instead.
Showing 2 changed files, with 90 additions and 14 deletions:

  megatron/arguments.py                      +12  -0
  megatron/core/tensor_parallel/layers.py    +78  -14
megatron/arguments.py

@@ -313,6 +313,18 @@ def validate_args(args, defaults={}):
     if args.sequence_parallel:
         args.async_tensor_model_parallel_allreduce = False
 
+    if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
+        if args.sequence_parallel:
+            raise RuntimeError(
+                "Using sequence parallelism requires setting the environment variable "
+                "CUDA_DEVICE_MAX_CONNECTIONS to 1")
+        if args.async_tensor_model_parallel_allreduce:
+            raise RuntimeError(
+                "Using async gradient all reduce requires setting the environment "
+                "variable CUDA_DEVICE_MAX_CONNECTIONS to 1")
+
     _print_args(args)
     return args
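With this change, validate_args() now fails fast when sequence parallelism or async tensor-model-parallel all-reduce is requested but the environment variable is not set. As a hedged illustration (not part of the commit), the variable has to be exported before any CUDA work starts; the torchrun invocation and the pretrain_gpt.py entry point below are assumptions for the sake of the example:

    # Illustration only: export the variable in the launching shell, e.g.
    #   CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=8 pretrain_gpt.py ...
    # or set it in Python before Megatron initializes CUDA and validates arguments:
    import os
    os.environ.setdefault('CUDA_DEVICE_MAX_CONNECTIONS', '1')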
megatron/core/tensor_parallel/layers.py

@@ -4,6 +4,7 @@
 # repo: https://github.com/pytorch/pytorch
 
 import math
+import os
 from typing import Optional
 import warnings
@@ -210,10 +211,7 @@ class VocabParallelEmbedding(torch.nn.Module):
 
 class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
-    """
-    Linear layer execution with asynchronous communication and gradient accumulation
-    fusion in backprop.
-    """
+    """See linear_with_grad_accumulation_and_async_allreduce"""
 
     @staticmethod
     def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
@@ -261,9 +259,8 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
                 input, group=get_tensor_model_parallel_group(), async_op=True)
 
-            # Delay the start of intput gradient computation shortly (3us) to have
-            # gather scheduled first and have GPU resources allocated
-            _ = torch.empty(1, device=grad_output.device) + 1
+            # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
+            # gather is scheduled before the input gradient computation
             total_input = all_gather_buffer
         else:
             total_input = input
@@ -282,9 +279,8 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
             # Asynchronous all-reduce
             handle = torch.distributed.all_reduce(
                     grad_input, group=get_tensor_model_parallel_group(), async_op=True)
-            # Delay the start of weight gradient computation shortly (3us) to have
-            # all-reduce scheduled first and have GPU resources allocated
-            _ = torch.empty(1, device=grad_output.device) + 1
+            # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
+            # all-reduce is scheduled before the weight gradient computation
 
         if ctx.sequence_parallel:
             assert not ctx.async_grad_allreduce
@@ -296,9 +292,8 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
             handle = torch.distributed._reduce_scatter_base(sub_grad_input, grad_input,
                                                             group=get_tensor_model_parallel_group(),
                                                             async_op=True)
-            # Delay the start of weight gradient computation shortly (3us) to have
-            # reduce scatter scheduled first and have GPU resources allocated
-            _ = torch.empty(1, device=grad_output.device) + 1
+            # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
+            # reduce scatter is scheduled before the weight gradient computation
 
         if ctx.gradient_accumulation_fusion:
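The three hunks above replace the "3us noop" trick with a documented reliance on CUDA_DEVICE_MAX_CONNECTIONS=1: with a single hardware connection, kernels reach the GPU in the order they are launched, so an async collective issued just before a matmul is guaranteed to start first and overlap with it. A minimal sketch of that pattern follows; it is not code from the commit, and it assumes torch.distributed has already been initialized with the NCCL backend (e.g. under torchrun):

    import os
    import torch
    import torch.distributed as dist

    # Must be set before the CUDA context is created for the ordering guarantee to hold.
    os.environ.setdefault('CUDA_DEVICE_MAX_CONNECTIONS', '1')

    def overlapped_allreduce_and_matmul(grad_output, total_input, grad_input):
        # Launch the all-reduce without blocking the host.
        handle = dist.all_reduce(grad_input, async_op=True)
        # With one CUDA connection the matmul is scheduled after the all-reduce
        # kernel, so communication and weight-gradient computation overlap; the
        # removed "_ = torch.empty(1, ...) + 1" noop tried to achieve the same
        # ordering by briefly delaying the compute kernel.
        grad_weight = grad_output.t().matmul(total_input)
        handle.wait()
        return grad_input, grad_weight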
@@ -330,6 +325,58 @@ def linear_with_grad_accumulation_and_async_allreduce(
     async_grad_allreduce: bool,
     sequence_parallel_enabled: bool,
 ) -> torch.Tensor:
+    """Linear layer execution with asynchronous communication and
+    gradient accumulation fusion in backprop.
+
+    This has the option to accumulate the result of the backprop
+    calculation into an existing gradient buffer, preventing the need
+    to do an additional addition kernel after the gradient
+    calculation.
+
+    Additionally, the tensor parallel all-reduce of the input
+    gradients can be done asynchronously with the calculation of
+    the weight gradients.
+
+    In the case of sequence parallelism, the reduce-scatter of the
+    input gradients is done asynchronously with the calculation of the
+    weight gradients.
+
+    Use of this module requires that the environment variable
+    CUDA_DEVICE_MAX_CONNECTIONS=1. There are a few collective
+    operations, noted in the code, that should be scheduled before
+    compute kernels to overlap the communication with the computation.
+    This ordering is necessary for a speedup but not for correctness,
+    so the scheduler does not impose it on its own. Setting
+    CUDA_DEVICE_MAX_CONNECTIONS=1 forces the kernels to be scheduled
+    in the order they are called.
+
+    Arguments:
+
+    input (torch.Tensor, required): input, like torch.nn.functional.linear
+
+    weight (torch.Tensor, required): weight, like torch.nn.functional.linear
+
+    bias (torch.Tensor, optional): bias, like torch.nn.functional.linear
+
+    gradient_accumulation_fusion (bool, required): Perform the gradient
+        accumulation fusion; requires the custom CUDA extension
+        fused_weight_gradient_mlp_cuda module. To use
+        gradient_accumulation_fusion you must install APEX with
+        --cpp_ext and --cuda_ext, for example:
+        pip install --global-option="--cpp_ext" --global-option="--cuda_ext" .
+        Note that the extension requires CUDA>=11. Otherwise, you
+        must turn off gradient accumulation fusion.
+
+    async_grad_allreduce (bool, required): Do the all-reduce of input
+        gradients asynchronously with the computation of weight
+        gradients. If sequence_parallel_enabled is True, this must be
+        False, as no all-reduce is performed.
+
+    sequence_parallel_enabled (bool, required): Indicates that sequence
+        parallelism is used, and thus in the forward pass the input is
+        all-gathered and in the backward pass the input gradients are
+        reduce-scattered.
+    """
     args = [
         input,
         weight,
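Given the new docstring, a hedged usage sketch follows. It is an illustration only: the shapes are placeholders, and it assumes the Megatron tensor-model-parallel state has already been initialized (the autograd function calls get_tensor_model_parallel_group() internally) and that CUDA_DEVICE_MAX_CONNECTIONS=1 is exported as required above:

    import torch
    from megatron.core.tensor_parallel.layers import (
        linear_with_grad_accumulation_and_async_allreduce,
    )

    # Placeholder shapes; per-rank partitioning of the weight is omitted for brevity.
    seq_len, batch, hidden, output_size = 1024, 4, 1024, 4096
    inp = torch.randn(seq_len, batch, hidden, device='cuda', requires_grad=True)
    weight = torch.randn(output_size, hidden, device='cuda', requires_grad=True)

    out = linear_with_grad_accumulation_and_async_allreduce(
        inp,
        weight,
        bias=None,
        gradient_accumulation_fusion=False,  # True needs APEX's fused_weight_gradient_mlp_cuda
        async_grad_allreduce=True,           # must be False when sequence_parallel_enabled is True
        sequence_parallel_enabled=False,
    )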
@@ -338,9 +385,26 @@ def linear_with_grad_accumulation_and_async_allreduce(
         async_grad_allreduce,
         sequence_parallel_enabled,
     ]
 
+    if not linear_with_grad_accumulation_and_async_allreduce.warned:
+        if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
+            if sequence_parallel_enabled:
+                warnings.warn(
+                    "When using sequence parallelism it is recommended to set the "
+                    "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
+                    "maximum speedup")
+                linear_with_grad_accumulation_and_async_allreduce.warned = True
+
+            if async_grad_allreduce:
+                warnings.warn(
+                    "When using async grad allreduce it is recommended to set the "
+                    "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
+                    "maximum speedup")
+                linear_with_grad_accumulation_and_async_allreduce.warned = True
+
     with torch.cuda.amp.autocast(enabled=False):
         return LinearWithGradAccumulationAndAsyncCommunication.apply(*args)
 
+linear_with_grad_accumulation_and_async_allreduce.warned = False
+
 
 class ColumnParallelLinear(torch.nn.Module):
     """Linear layer with column parallelism.
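The warned flag added at module scope implements a warn-once guard stored as an attribute on the function object itself, so the recommendation is printed at most once per process. A standalone sketch of the same pattern, with made-up names, purely for illustration:

    import os
    import warnings

    def check_max_connections():
        # Mirrors the function-attribute pattern above: check_max_connections.warned
        # plays the role of linear_with_grad_accumulation_and_async_allreduce.warned.
        if not check_max_connections.warned:
            if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
                warnings.warn("Set CUDA_DEVICE_MAX_CONNECTIONS=1 for maximum speedup")
                check_max_connections.warned = True

    check_max_connections.warned = False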