OpenDAS / Megatron-LM — commit 9c5a830f

Commit 9c5a830f authored Feb 22, 2022 by Jared Casper

Merge branch 'slym/grad_accum_fusion' into 'main'

Gradient accumulation fusion

See merge request ADLR/megatron-lm!394

Parents: 0ed2f6ac, b5726555

Showing 9 changed files with 311 additions and 38 deletions (+311, -38)
megatron/arguments.py                                     +16   -2
megatron/fused_kernels/__init__.py                        +10   -0
megatron/fused_kernels/fused_weight_gradient_dense.cpp    +47   -0
megatron/fused_kernels/fused_weight_gradient_dense.cu    +157   -0
megatron/fused_kernels/type_shim.h                        +26   -0
megatron/model/distributed.py                              +2   -2
megatron/model/language_model.py                          +11   -5
megatron/mpu/__init__.py                                   +1   -0
megatron/mpu/layers.py                                    +41  -29
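For orientation before the per-file diffs: the merge fuses the weight-gradient GEMM of linear layers with the fp32 gradient accumulation that Megatron's local DDP already keeps in a main_grad buffer. A minimal, hypothetical PyTorch sketch of the idea (illustrative only, not the shipped CUDA kernel; shapes and names are assumptions):

    import torch

    # Illustrative shapes: [tokens, in_dim] activations, [tokens, out_dim] output grads.
    # In Megatron these would be fp16/bf16 tensors on the GPU; fp32 here keeps the
    # sketch runnable anywhere.
    tokens, in_dim, out_dim = 8, 4, 6
    inp = torch.randn(tokens, in_dim)
    grad_output = torch.randn(tokens, out_dim)

    # fp32 accumulation buffer kept alongside the weight (Megatron's `main_grad`).
    main_grad = torch.zeros(out_dim, in_dim)

    # Unfused path: materialize grad_weight, then add it into the buffer.
    grad_weight = grad_output.t().matmul(inp)
    unfused = main_grad + grad_weight

    # Fused path (what the new wgrad_gemm_accum_fp32 kernel does in a single
    # cuBLAS GemmEx call with beta = 1): accumulate straight into the fp32
    # buffer, with no intermediate grad_weight tensor.
    fused = main_grad.clone()
    fused.addmm_(grad_output.t(), inp)  # fused += grad_output^T @ inp

    print(torch.allclose(unfused, fused))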
megatron/arguments.py

...
@@ -172,6 +172,14 @@ def parse_args(extra_args_provider=None, defaults={},
     if args.accumulate_allreduce_grads_in_fp32:
         assert args.DDP_impl == 'local'
         assert args.use_contiguous_buffers_in_local_ddp
+    else:
+        if args.gradient_accumulation_fusion:
+            args.gradient_accumulation_fusion = False
+            if args.rank == 0:
+                print('Gradient accumulation fusion to linear layer weight '
+                      'gradient computation is supported only with fp32 '
+                      'gradient accumulation. Setting gradient_accumulation_fusion '
+                      'to False', flush=True)

     # For torch DDP, we do not use contiguous buffer
     if args.DDP_impl == 'torch':
...
@@ -521,15 +529,21 @@ def _add_training_args(parser):
                        choices=['single', 'cyclic'],
                        help='Single pass vs multiple pass data loader')
     group.add_argument('--no-async-tensor-model-parallel-allreduce',
-                       action='store_true',
+                       action='store_false',
                        help='Disable asynchronous execution of '
                        'tensor-model-parallel all-reduce with weight '
-                       'gradient compuation of a column-linear layer.')
+                       'gradient compuation of a column-linear layer.',
+                       dest='async_tensor_model_parallel_allreduce')
     group.add_argument('--no-persist-layer-norm', action='store_true',
                        help='Disable using persistent fused layer norm kernel. '
                        'This kernel supports only a set of hidden sizes. Please '
                        'check persist_ln_hidden_sizes if your hidden '
                        'size is supported.')
+    group.add_argument('--no-gradient-accumulation-fusion',
+                       action='store_false',
+                       help='Disable fusing gradient accumulation to weight '
+                       'gradient computation of linear layers',
+                       dest='gradient_accumulation_fusion')
     return parser
...
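As an aside on the argparse pattern above: action='store_false' plus a positively named dest makes both --no-* options default to True, so the new fusion is on by default and the flag only turns it off. A small standalone sketch (hypothetical parser, not Megatron's):

    import argparse

    parser = argparse.ArgumentParser()
    # Same pattern as the diff: a --no-* flag with store_false and a positive
    # dest defaults to True; passing the flag flips it to False.
    parser.add_argument('--no-gradient-accumulation-fusion',
                        action='store_false',
                        dest='gradient_accumulation_fusion',
                        help='Disable fusing gradient accumulation into the '
                             'weight-gradient GEMM of linear layers')

    print(parser.parse_args([]).gradient_accumulation_fusion)  # True
    print(parser.parse_args(
        ['--no-gradient-accumulation-fusion']).gradient_accumulation_fusion)  # False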
megatron/fused_kernels/__init__.py

...
@@ -94,6 +94,16 @@ def load(args):
     fused_mix_prec_layer_norm_cuda = _cpp_extention_load_helper(
         "fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags)

+    # =================================
+    # Fused gradient accumulation to weight gradient computation of linear layer
+    # =================================
+
+    if args.gradient_accumulation_fusion:
+        sources = [srcpath / 'fused_weight_gradient_dense.cpp',
+                   srcpath / 'fused_weight_gradient_dense.cu']
+        fused_dense_cuda = _cpp_extention_load_helper(
+            "fused_dense_cuda", sources, [])
+

 def _get_cuda_bare_metal_version(cuda_dir):
     raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
...
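_cpp_extention_load_helper is not shown in this hunk; it presumably wraps torch.utils.cpp_extension.load. A hedged sketch of JIT-building the same two sources directly (the path and the lack of extra flags are assumptions, and only an approximation of what the helper does):

    import pathlib
    from torch.utils.cpp_extension import load

    srcpath = pathlib.Path('megatron/fused_kernels')  # assumed source location

    # JIT-compile the C++ binding and the CUDA kernel into an importable module.
    fused_dense_cuda = load(
        name='fused_dense_cuda',
        sources=[str(srcpath / 'fused_weight_gradient_dense.cpp'),
                 str(srcpath / 'fused_weight_gradient_dense.cu')],
        verbose=True)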
megatron/fused_kernels/fused_weight_gradient_dense.cpp (new file, 0 → 100644)

#include <torch/torch.h>
#include <torch/extension.h>
#include <vector>
#include <stdio.h>

#include "type_shim.h"


template <typename T>
int wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight,
                               int in_dim, int hidden_dim, int out_dim);


void wgrad_gemm_accum_fp32(const at::Tensor input, const at::Tensor d_output,
                           at::Tensor d_weight) {
    at::Tensor input_2d, d_output_2d;
    // input tensor: collapse to the first dim
    auto in_sizes = input.sizes();
    if (input.dim() > 2) {
        input_2d = input.view({-1, in_sizes[in_sizes.size() - 1]});
    } else {
        input_2d = input;
    }
    // d_output tensor: collapse to the first dim
    auto d_out_sizes = d_output.sizes();
    if (d_output.dim() > 2) {
        d_output_2d = d_output.view({-1, d_out_sizes[d_out_sizes.size() - 1]});
    } else {
        d_output_2d = d_output;
    }

    int hidden_dim = input_2d.size(0);
    int in_dim = input_2d.size(1);
    int out_dim = d_weight.size(0);

    DISPATCH_HALF_BFLOAT_AND_FLOAT(input_2d.scalar_type(), "wgrad_gemm_accum_fp32",
        int result = wgrad_gemm_accum_fp32_cuda<scalar_t>(
            input_2d.data_ptr<scalar_t>(),
            d_output_2d.data_ptr<scalar_t>(),
            d_weight.data_ptr<float>(),
            in_dim,
            hidden_dim,
            out_dim);
    );
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("wgrad_gemm_accum_fp32", &wgrad_gemm_accum_fp32, "wgrad gemm accum in fp32");
}
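Once the extension is built (e.g. by megatron.fused_kernels.load), the module defined by PYBIND11_MODULE above can be called from Python with fp16/bf16 activations and output gradients and an fp32 weight-gradient buffer; the wrapper collapses any leading dimensions to 2D itself. A hedged usage sketch (requires a CUDA build of the extension in the same process; the shapes are made up):

    import torch
    import fused_dense_cuda  # the module defined by PYBIND11_MODULE above

    seq, batch, in_dim, out_dim = 16, 4, 128, 256
    inp = torch.randn(seq, batch, in_dim, dtype=torch.half, device='cuda')
    grad_output = torch.randn(seq, batch, out_dim, dtype=torch.half, device='cuda')
    main_grad = torch.zeros(out_dim, in_dim, dtype=torch.float, device='cuda')

    # The C++ wrapper flattens [seq, batch, dim] to [seq * batch, dim] itself and
    # accumulates grad_output^T @ inp into main_grad in fp32 (beta = 1).
    fused_dense_cuda.wgrad_gemm_accum_fp32(inp, grad_output, main_grad)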
megatron/fused_kernels/fused_weight_gradient_dense.cu (new file, 0 → 100644)

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <torch/torch.h>

/* Includes, cuda */
#include <cublas_v2.h>
#include <cuda_runtime.h>


// BF16 Tensor core wrapper around cublas GEMMEx
cublasStatus_t gemmex_wrapper(
    cublasHandle_t handle,
    cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k,
    const float* alpha,
    at::BFloat16* A, int lda,
    at::BFloat16* B, int ldb,
    const float* beta,
    float* C, int ldc) {
  return cublasGemmEx(
      handle, transa, transb, m, n, k,
      alpha,
      A, CUDA_R_16BF, lda,
      B, CUDA_R_16BF, ldb,
      beta,
      C, CUDA_R_32F, ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}

// FP16 Tensor core wrapper around cublas GEMMEx
cublasStatus_t gemmex_wrapper(
    cublasHandle_t handle,
    cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k,
    const float* alpha,
    at::Half* A, int lda,
    at::Half* B, int ldb,
    const float* beta,
    float* C, int ldc) {
  return cublasGemmEx(
      handle, transa, transb, m, n, k,
      alpha,
      A, CUDA_R_16F, lda,
      B, CUDA_R_16F, ldb,
      beta,
      C, CUDA_R_32F, ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}

// FP32 Tensor core wrapper around cublas GEMMEx
cublasStatus_t gemmex_wrapper(
    cublasHandle_t handle,
    cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k,
    const float* alpha,
    float* A, int lda,
    float* B, int ldb,
    const float* beta,
    float* C, int ldc) {
  return cublasGemmEx(
      handle, transa, transb, m, n, k,
      alpha,
      A, CUDA_R_32F, lda,
      B, CUDA_R_32F, ldb,
      beta,
      C, CUDA_R_32F, ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}

template <typename T>
int wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight,
                               int in_dim, int hidden_dim, int out_dim) {
    cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
    cudaStream_t stream;
    cublasGetStream(handle, &stream);
    const float alpha = 1.0;
    const float beta  = 1.0;
    int status = 1;

    status = gemmex_wrapper(
        handle,
        CUBLAS_OP_N,
        CUBLAS_OP_T,
        in_dim,
        out_dim,
        hidden_dim,
        &alpha,
        input,
        in_dim,
        d_output,
        out_dim,
        &beta,
        d_weight,
        in_dim);
    return status;
}

template int wgrad_gemm_accum_fp32_cuda<at::Half>(
    at::Half *input, at::Half *d_output, float *d_weight,
    int in_dim, int hidden_dim, int out_dim);

template int wgrad_gemm_accum_fp32_cuda<at::BFloat16>(
    at::BFloat16 *input, at::BFloat16 *d_output, float *d_weight,
    int in_dim, int hidden_dim, int out_dim);

template int wgrad_gemm_accum_fp32_cuda<float>(
    float *input, float *d_output, float *d_weight,
    int in_dim, int hidden_dim, int out_dim);
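In row-major PyTorch terms, the cublasGemmEx call above (op(A)=N, op(B)=T, m=in_dim, n=out_dim, k=hidden_dim, alpha=beta=1; cuBLAS sees the row-major buffers as their transposes) accumulates d_weight += d_output^T @ input with fp32 compute and fp32 accumulation. A small sketch of that equivalence, written against plain PyTorch rather than the kernel:

    import torch

    # hidden_dim is the collapsed token dimension (seq * batch).
    hidden_dim, in_dim, out_dim = 1024, 512, 768
    inp = torch.randn(hidden_dim, in_dim)
    d_output = torch.randn(hidden_dim, out_dim)
    d_weight = torch.zeros(out_dim, in_dim)

    # Equivalent of the cuBLAS call: d_weight += d_output^T @ inp with beta = 1,
    # accumulated in fp32 throughout.
    d_weight.addmm_(d_output.t(), inp)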
megatron/fused_kernels/type_shim.h

...
@@ -39,6 +39,32 @@
   }

+#define DISPATCH_HALF_BFLOAT_AND_FLOAT(TYPE, NAME, ...)                   \
+  switch(TYPE)                                                            \
+  {                                                                       \
+    case at::ScalarType::Half:                                            \
+    {                                                                     \
+      using scalar_t = at::Half;                                          \
+      __VA_ARGS__;                                                        \
+      break;                                                              \
+    }                                                                     \
+    case at::ScalarType::BFloat16:                                        \
+    {                                                                     \
+      using scalar_t = at::BFloat16;                                      \
+      __VA_ARGS__;                                                        \
+      break;                                                              \
+    }                                                                     \
+    case at::ScalarType::Float:                                           \
+    {                                                                     \
+      using scalar_t = float;                                             \
+      __VA_ARGS__;                                                        \
+      break;                                                              \
+    }                                                                     \
+    default:                                                              \
+      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");     \
+  }
+
 #define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
   switch(TYPEIN)                                                               \
...
megatron/model/distributed.py

...
@@ -164,13 +164,13 @@ class DistributedDataParallel(DistributedDataParallelBase):
                 grad_acc.register_hook(self._make_param_hook(param))
                 self.grad_accs.append(grad_acc)

     def _make_param_hook(self, param):
         """Create the all-reduce hook for backprop."""
         # Hook used for back-prop.
         def param_hook(*unused):
             # Add the gradient to the buffer.
-            if param.grad.data is not None:
+            if param.grad is not None:
+                # The gradient function of linear layers is fused with GEMMs
                 param.main_grad.add_(param.grad.data)
                 # Now we can deallocate grad memory.
                 param.grad = None
...
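The switch from `param.grad.data is not None` to `param.grad is not None` matters because, with fusion enabled, the fused kernel writes the weight gradient straight into param.main_grad, so param.grad can legitimately be None for those parameters (the old attribute access would raise). A stripped-down sketch of the accumulate-into-main_grad hook (simplified; not the full DistributedDataParallel wrapper):

    import torch

    def make_param_hook(param):
        # Accumulate incoming grads into the persistent fp32 buffer, then free
        # the temporary .grad, mirroring _make_param_hook above.
        def param_hook(*unused):
            if param.grad is not None:  # may be None when the fused kernel
                param.main_grad.add_(param.grad.data)  # already wrote main_grad
                param.grad = None
        return param_hook

    param = torch.nn.Parameter(torch.randn(4, 4))
    param.main_grad = torch.zeros_like(param, dtype=torch.float)
    # Megatron registers this hook on the gradient-accumulation autograd node;
    # calling it directly here just illustrates the accumulation step.
    param.grad = torch.ones_like(param)
    make_param_hook(param)()
    print(param.main_grad.sum(), param.grad)  # tensor(16.) None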
megatron/model/language_model.py

...
@@ -29,13 +29,19 @@ from megatron.model.utils import init_method_normal, scaled_init_method_normal
 def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
                        bias=None):
     """LM logits using word embedding weights."""
+    args = get_args()
     # Parallel logits.
-    input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
-    # Matrix multiply.
-    if bias is None:
-        logits_parallel = F.linear(input_parallel, word_embeddings_weight)
+    if args.async_tensor_model_parallel_allreduce:
+        input_parallel = input_
+        async_grad_allreduce = mpu.get_tensor_model_parallel_world_size() > 1
     else:
-        logits_parallel = F.linear(input_parallel, word_embeddings_weight, bias)
+        input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
+        async_grad_allreduce = False
+
+    # Matrix multiply.
+    logits_parallel = mpu.LinearWithGradAccumulationAndAsyncAllreduce.apply(
+        input_parallel, word_embeddings_weight, bias,
+        args.gradient_accumulation_fusion, async_grad_allreduce)
     # Gather if needed.
     if parallel_output:
         return logits_parallel
...
megatron/mpu/__init__.py

...
@@ -49,6 +49,7 @@ from .initialize import get_virtual_pipeline_model_parallel_rank, set_virtual_pi
 from .initialize import initialize_model_parallel
 from .initialize import model_parallel_is_initialized

+from .layers import LinearWithGradAccumulationAndAsyncAllreduce
 from .layers import ColumnParallelLinear
 from .layers import RowParallelLinear
 from .layers import VocabParallelEmbedding
...
megatron/mpu/layers.py

...
@@ -199,15 +199,18 @@ class VocabParallelEmbedding(torch.nn.Module):
         return output


-class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):
+class LinearWithGradAccumulationAndAsyncAllreduce(torch.autograd.Function):
     """
-    Column-parallel linear layer execution with asynchronous all-reduce
-    execution in backprop.
+    Linear layer execution with asynchronous all-reduce and
+    gradient accumulation fusion in backprop.
     """

     @staticmethod
-    def forward(ctx, input, weight, bias):
+    def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
+                async_grad_allreduce):
         ctx.save_for_backward(input, weight)
         ctx.use_bias = bias is not None
+        ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
+        ctx.async_grad_allreduce = async_grad_allreduce

         output = torch.matmul(input, weight.t())
         if bias is not None:
             output = output + bias
...
@@ -215,19 +218,32 @@ class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
+        import fused_dense_cuda
         input, weight = ctx.saved_tensors
         use_bias = ctx.use_bias
         grad_input = grad_output.matmul(weight)
-        # Asyncronous all-reduce
-        handle = torch.distributed.all_reduce(
-                grad_input, group=get_tensor_model_parallel_group(), async_op=True)
-        # Delay the start of weight gradient computation shortly (3us) to have
-        # all-reduce scheduled first and have GPU resources allocated
-        _ = torch.empty(1, device=grad_output.device) + 1
-        grad_weight = grad_output.t().matmul(input)
+        # Convert the tensor shapes to 2D for execution compatibility
+        grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1],
+                                       grad_output.shape[2])
+        input = input.view(input.shape[0] * input.shape[1], input.shape[2])
+
+        if ctx.async_grad_allreduce:
+            # Asynchronous all-reduce
+            handle = torch.distributed.all_reduce(
+                    grad_input, group=get_tensor_model_parallel_group(), async_op=True)
+            # Delay the start of weight gradient computation shortly (3us) to have
+            # all-reduce scheduled first and have GPU resources allocated
+            _ = torch.empty(1, device=grad_output.device) + 1
+
+        if ctx.gradient_accumulation_fusion:
+            fused_dense_cuda.wgrad_gemm_accum_fp32(input, grad_output, weight.main_grad)
+            grad_weight = None
+        else:
+            grad_weight = grad_output.t().matmul(input)
         grad_bias = grad_output.sum(dim=0) if use_bias else None
-        handle.wait()
-        return grad_input, grad_weight, grad_bias
+        if ctx.async_grad_allreduce:
+            handle.wait()
+        return grad_input, grad_weight, grad_bias, None, None


 class ColumnParallelLinear(torch.nn.Module):
...
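One detail of the new backward worth spelling out: forward now takes five arguments, two of which are plain booleans, so backward must return five values with None for the non-tensor inputs. A toy sketch of that contract (hypothetical class, not the Megatron one):

    import torch

    class LinearWithFlags(torch.autograd.Function):
        """Toy version of the 5-argument forward / 5-value backward contract."""

        @staticmethod
        def forward(ctx, input, weight, bias, fuse_wgrad, async_allreduce):
            ctx.save_for_backward(input, weight)
            ctx.use_bias = bias is not None
            output = torch.matmul(input, weight.t())
            if bias is not None:
                output = output + bias
            return output

        @staticmethod
        def backward(ctx, grad_output):
            input, weight = ctx.saved_tensors
            grad_input = grad_output.matmul(weight)
            grad_weight = grad_output.t().matmul(input)
            grad_bias = grad_output.sum(dim=0) if ctx.use_bias else None
            # One gradient per forward argument; the non-tensor flags get None.
            return grad_input, grad_weight, grad_bias, None, None

    x = torch.randn(8, 4, requires_grad=True)
    w = torch.randn(6, 4, requires_grad=True)
    y = LinearWithFlags.apply(x, w, None, True, False)
    y.sum().backward()
    print(x.grad.shape, w.grad.shape)  # torch.Size([8, 4]) torch.Size([6, 4])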
@@ -240,7 +256,7 @@ class ColumnParallelLinear(torch.nn.Module):
     input_size: first dimension of matrix A.
     output_size: second dimension of matrix A.
     bias: If true, add bias
-    gather_output: If true, call all-gather on output and make Y avaiable
+    gather_output: If true, call all-gather on output and make Y available
                    to all GPUs, otherwise, every GPU will have its output
                    which is Y_i = XA_i
     init_method: method to initialize weights. Note that bias is always set
...
@@ -305,29 +321,23 @@ class ColumnParallelLinear(torch.nn.Module):
         else:
             self.register_parameter('bias', None)

         self.async_tensor_model_parallel_allreduce = (
-                not args.no_async_tensor_model_parallel_allreduce and
+                args.async_tensor_model_parallel_allreduce and
                 world_size > 1)
+        self.gradient_accumulation_fusion = args.gradient_accumulation_fusion

     def forward(self, input_):
         bias = self.bias if not self.skip_bias_add else None

         if self.async_tensor_model_parallel_allreduce:
-            input_shape = input_.shape
-            input_ = input_.view(input_shape[0] * input_shape[1], input_shape[2])
-            # Maxtrix multiply with asynchronouse all-reduce execution
-            output_parallel = ColumnParallelLinearWithAsyncAllreduce.apply(
-                    input_, self.weight, bias)
-            output_parallel = output_parallel.view(
-                    input_shape[0], input_shape[1], output_parallel.shape[1])
+            input_parallel = input_
         else:
             # Set up backprop all-reduce.
             input_parallel = copy_to_tensor_model_parallel_region(input_)
-            # Matrix multiply.
-            output_parallel = F.linear(input_parallel, self.weight, bias)
+        # Matrix multiply.
+        output_parallel = LinearWithGradAccumulationAndAsyncAllreduce.apply(
+            input_parallel, self.weight, bias, self.gradient_accumulation_fusion,
+            self.async_tensor_model_parallel_allreduce)
         if self.gather_output:
             # All-gather across the partitions.
             output = gather_from_tensor_model_parallel_region(output_parallel)
...
@@ -415,7 +425,7 @@ class RowParallelLinear(torch.nn.Module):
                 self.bias.zero_()
         else:
             self.register_parameter('bias', None)
        self.gradient_accumulation_fusion = args.gradient_accumulation_fusion

    def forward(self, input_):
...
@@ -425,7 +435,9 @@ class RowParallelLinear(torch.nn.Module):
         else:
             input_parallel = scatter_to_tensor_model_parallel_region(input_)
         # Matrix multiply.
-        output_parallel = F.linear(input_parallel, self.weight)
+        output_parallel = LinearWithGradAccumulationAndAsyncAllreduce.apply(
+            input_parallel, self.weight, None,
+            self.gradient_accumulation_fusion, None)
         # All-reduce across all the partitions.
         output_ = reduce_from_tensor_model_parallel_region(output_parallel)
         if not self.skip_bias_add:
...