Commit d3a416cd authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'core-noop' into 'core'

Remove noop used to try to force scheduling and check for environment variable instead.

See merge request ADLR/megatron-lm!463
parents 5da3bb92 bdd97312
...@@ -313,6 +313,18 @@ def validate_args(args, defaults={}): ...@@ -313,6 +313,18 @@ def validate_args(args, defaults={}):
if args.sequence_parallel: if args.sequence_parallel:
args.async_tensor_model_parallel_allreduce = False args.async_tensor_model_parallel_allreduce = False
if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
if args.sequence_parallel:
raise RuntimeError(
"Using sequence parallelism requires setting the environment variable "
"CUDA_DEVICE_MAX_CONNECTIONS to 1")
if args.async_tensor_model_parallel_allreduce:
raise RuntimeError(
"Using async gradient all reduce requires setting the environment "
"variable CUDA_DEVICE_MAX_CONNECTIONS to 1")
_print_args(args) _print_args(args)
return args return args
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
# repo: https://github.com/pytorch/pytorch # repo: https://github.com/pytorch/pytorch
import math import math
import os
from typing import Optional from typing import Optional
import warnings import warnings
...@@ -210,10 +211,7 @@ class VocabParallelEmbedding(torch.nn.Module): ...@@ -210,10 +211,7 @@ class VocabParallelEmbedding(torch.nn.Module):
class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
""" """See linear_with_grad_accumulation_and_async_allreduce"""
Linear layer execution with asynchronous communication and gradient accumulation
fusion in backprop.
"""
@staticmethod @staticmethod
def forward(ctx, input, weight, bias, gradient_accumulation_fusion, def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
...@@ -261,9 +259,8 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): ...@@ -261,9 +259,8 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
input, input,
group=get_tensor_model_parallel_group(), async_op=True) group=get_tensor_model_parallel_group(), async_op=True)
# Delay the start of intput gradient computation shortly (3us) to have # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# gather scheduled first and have GPU resources allocated # gather is scheduled before the input gradient computation
_ = torch.empty(1, device=grad_output.device) + 1
total_input = all_gather_buffer total_input = all_gather_buffer
else: else:
total_input = input total_input = input
...@@ -282,9 +279,8 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): ...@@ -282,9 +279,8 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
# Asynchronous all-reduce # Asynchronous all-reduce
handle = torch.distributed.all_reduce( handle = torch.distributed.all_reduce(
grad_input, group=get_tensor_model_parallel_group(), async_op=True) grad_input, group=get_tensor_model_parallel_group(), async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# all-reduce scheduled first and have GPU resources allocated # all-reduce is scheduled before the weight gradient computation
_ = torch.empty(1, device=grad_output.device) + 1
if ctx.sequence_parallel: if ctx.sequence_parallel:
assert not ctx.async_grad_allreduce assert not ctx.async_grad_allreduce
...@@ -296,9 +292,8 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): ...@@ -296,9 +292,8 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
handle = torch.distributed._reduce_scatter_base(sub_grad_input, grad_input, handle = torch.distributed._reduce_scatter_base(sub_grad_input, grad_input,
group=get_tensor_model_parallel_group(), group=get_tensor_model_parallel_group(),
async_op=True) async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# reduce scatter scheduled first and have GPU resources allocated # reduce scatter is scheduled before the weight gradient computation
_ = torch.empty(1, device=grad_output.device) + 1
if ctx.gradient_accumulation_fusion: if ctx.gradient_accumulation_fusion:
...@@ -330,6 +325,58 @@ def linear_with_grad_accumulation_and_async_allreduce( ...@@ -330,6 +325,58 @@ def linear_with_grad_accumulation_and_async_allreduce(
async_grad_allreduce: bool, async_grad_allreduce: bool,
sequence_parallel_enabled: bool, sequence_parallel_enabled: bool,
) -> torch.Tensor: ) -> torch.Tensor:
"""Linear layer execution with asynchronous communication and
gradient accumulation fusion in backprop.
This has the option to accumulate the result of backprop
calculation into an existing gradient buffer, preventing the need
to do an additional addition kernel after the gradient
calculation.
Additionally, the tensor parallel all reduce of the input
gradients can be done asynchronously with the calculation of
the weight gradients.
In the case of sequence parallelism, the reduce scatter of the
input gradients is done asynchronously with the calcluation of the
weight gradients.
Use of this module requires that the environment variable
CUDA_DEVICE_MAX_CONNECTIONS=1. There are a few collective
operations, noted in the code, that should be scheduled before
compute kernels to overlap the communication with the computation,
which is necessary for a speedup but not for correctness so that
ordering isn't imposed by the scheduler. Setting
CUDA_DEVICE_MAX_CONNECTIONS=1 forces the kernels to be scheduled
in the order they are called.
Arguments:
input (torch.Tensor required): input like torch.nn.functional.linear
weight (torch.Tensor required): weight like torch.nn.functional.linear
bias (torch.Tensor optional): bias like torch.nn.functional.linear
gradient_accumulation_fusion (bool required): Perform the gradient
accumulation fusion, requires the custom CUDA extension
fused_weight_gradient_mlp_cuda module. To use
gradient_accumulation_fusion you must install APEX with
--cpp_ext and --cuda_ext. For example: "pip install
--global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\"
" Note that the extension requires CUDA>=11. Otherwise, you
must turn off gradient accumulation fusion."
async_grad_allreduce (bool required): Do the allreduce of input
gradients asyncronously with the computation of weight
gradients. If sequence_parallel_enabled is True, this must be
False, as no all reduce is performed.
sequence_parallel_enabled (bool required): Indicates that sequence
parallelism is used and thus in the forward pass the input is
all gathered, and the backward pass the input gradients are
reduce scattered.
"""
args = [ args = [
input, input,
weight, weight,
...@@ -338,9 +385,26 @@ def linear_with_grad_accumulation_and_async_allreduce( ...@@ -338,9 +385,26 @@ def linear_with_grad_accumulation_and_async_allreduce(
async_grad_allreduce, async_grad_allreduce,
sequence_parallel_enabled, sequence_parallel_enabled,
] ]
if not linear_with_grad_accumulation_and_async_allreduce.warned:
if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
if sequence_parallel_enabled:
warnings.warn(
"When using sequence parallelism it is recommended to set the "
"environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
"maximum speedup")
linear_with_grad_accumulation_and_async_allreduce.warned = True
if async_grad_allreduce:
warnings.warn(
"When using async grad allreduce it is recommended to set the "
"environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
"maximum speedup")
linear_with_grad_accumulation_and_async_allreduce.warned = True
with torch.cuda.amp.autocast(enabled=False): with torch.cuda.amp.autocast(enabled=False):
return LinearWithGradAccumulationAndAsyncCommunication.apply(*args) return LinearWithGradAccumulationAndAsyncCommunication.apply(*args)
linear_with_grad_accumulation_and_async_allreduce.warned = False
class ColumnParallelLinear(torch.nn.Module): class ColumnParallelLinear(torch.nn.Module):
"""Linear layer with column parallelism. """Linear layer with column parallelism.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment