OpenDAS / Megatron-LM · Commit 6fdbf26b
Authored Mar 07, 2022 by Vijay Korthikanti

Merge branch 'main' into allreduce-split

Parents: 0d77c0e9, 53f3efc4
Showing 10 changed files with 325 additions and 20 deletions (+325 −20):

megatron/arguments.py (+18 −5)
megatron/fused_kernels/__init__.py (+10 −0)
megatron/fused_kernels/fused_weight_gradient_dense.cpp (+47 −0)
megatron/fused_kernels/fused_weight_gradient_dense.cu (+157 −0)
megatron/fused_kernels/type_shim.h (+26 −0)
megatron/model/distributed.py (+2 −2)
megatron/model/language_model.py (+0 −1)
megatron/model/transformer.py (+51 −2)
megatron/mpu/__init__.py (+1 −0)
megatron/mpu/layers.py (+13 −10)
megatron/arguments.py

@@ -172,6 +172,14 @@ def parse_args(extra_args_provider=None, defaults={},
    if args.accumulate_allreduce_grads_in_fp32:
        assert args.DDP_impl == 'local'
        assert args.use_contiguous_buffers_in_local_ddp
    else:
        if args.gradient_accumulation_fusion:
            args.gradient_accumulation_fusion = False
            if args.rank == 0:
                print('Gradient accumulation fusion to linear layer weight '
                      'gradient computation is supported only with fp32 '
                      'gradient accumulation. Setting gradient_accumulation_fusion '
                      'to False', flush=True)

    # For torch DDP, we do not use contiguous buffer
    if args.DDP_impl == 'torch':

@@ -357,7 +365,8 @@ def _add_network_size_args(parser):
    group.add_argument('--bert-no-binary-head', action='store_false',
                       help='Disable BERT binary head.',
                       dest='bert_binary_head')
    group.add_argument('--num-experts', type=int, default=None,
                       help='Number of Experts in Switch Transformer (None means no Switch)')

    return parser

@@ -521,10 +530,11 @@ def _add_training_args(parser):
                       choices=['single', 'cyclic'],
                       help='Single pass vs multiple pass data loader')
    group.add_argument('--no-async-tensor-model-parallel-allreduce',
-                      action='store_true',
+                      action='store_false',
                       help='Disable asynchronous execution of '
                       'tensor-model-parallel all-reduce with weight '
-                      'gradient computation of a column-linear layer.')
+                      'gradient computation of a column-linear layer.',
+                      dest='async_tensor_model_parallel_allreduce')
    group.add_argument('--no-persist-layer-norm', action='store_true',
                       help='Disable using persistent fused layer norm kernel. '
                       'This kernel supports only a set of hidden sizes. Please '

@@ -532,8 +542,11 @@ def _add_training_args(parser):
                       'size is supported.')
    group.add_argument('--model-parallel-memory-opt', action='store_true',
                       help='Enable model parallel memory optimization.')
    group.add_argument('--no-gradient-accumulation-fusion', action='store_false',
                       help='Disable fusing gradient accumulation to weight '
                       'gradient computation of linear layers',
                       dest='gradient_accumulation_fusion')

    return parser
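Note the pattern shared by both new flags: a `--no-X` switch with `action='store_false'` and `dest` pointing at the positive attribute name, so the option defaults to True and the flag turns it off. A minimal standalone sketch of that behavior (illustrative, not part of the commit):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--no-gradient-accumulation-fusion',
                    action='store_false',
                    dest='gradient_accumulation_fusion',
                    help='Disable fusing gradient accumulation.')

# store_false defaults the destination to True
print(parser.parse_args([]).gradient_accumulation_fusion)                                      # True
# passing the flag disables the feature
print(parser.parse_args(['--no-gradient-accumulation-fusion']).gradient_accumulation_fusion)  # False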
megatron/fused_kernels/__init__.py

@@ -94,6 +94,16 @@ def load(args):
        fused_mix_prec_layer_norm_cuda = _cpp_extention_load_helper(
            "fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags)

    # =================================
    # Fused gradient accumulation to weight gradient computation of linear layer
    # =================================

    if args.gradient_accumulation_fusion:
        sources = [srcpath / 'fused_weight_gradient_dense.cpp',
                   srcpath / 'fused_weight_gradient_dense.cu']
        fused_dense_cuda = _cpp_extention_load_helper(
            "fused_dense_cuda", sources, [])


def _get_cuda_bare_metal_version(cuda_dir):
    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
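The body of `_cpp_extention_load_helper` is not shown in this hunk; a hedged sketch of the JIT build it presumably wraps, using `torch.utils.cpp_extension.load` directly:

import pathlib
from torch.utils import cpp_extension

srcpath = pathlib.Path('megatron/fused_kernels')
# Build the extension from the two new source files; the diff passes an
# empty extra-flags list for this kernel.
fused_dense_cuda = cpp_extension.load(
    name='fused_dense_cuda',
    sources=[str(srcpath / 'fused_weight_gradient_dense.cpp'),
             str(srcpath / 'fused_weight_gradient_dense.cu')],
    extra_cuda_cflags=[],
    verbose=True)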
megatron/fused_kernels/fused_weight_gradient_dense.cpp (new file, mode 100644)

#include <torch/torch.h>
#include <torch/extension.h>
#include <vector>
#include <stdio.h>

#include "type_shim.h"

template <typename T>
int wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight,
                               int in_dim, int hidden_dim, int out_dim);

void wgrad_gemm_accum_fp32(const at::Tensor input, const at::Tensor d_output, at::Tensor d_weight) {
    at::Tensor input_2d, d_output_2d;
    // input tensor: collapse to the first dim
    auto in_sizes = input.sizes();
    if (input.dim() > 2) {
        input_2d = input.view({-1, in_sizes[in_sizes.size() - 1]});
    } else {
        input_2d = input;
    }
    // d_output tensor: collapse to the first dim
    auto d_out_sizes = d_output.sizes();
    if (d_output.dim() > 2) {
        d_output_2d = d_output.view({-1, d_out_sizes[d_out_sizes.size() - 1]});
    } else {
        d_output_2d = d_output;
    }

    int hidden_dim = input_2d.size(0);
    int in_dim = input_2d.size(1);
    int out_dim = d_weight.size(0);

    DISPATCH_HALF_BFLOAT_AND_FLOAT(input_2d.scalar_type(), "wgrad_gemm_accum_fp32",
        int result = wgrad_gemm_accum_fp32_cuda<scalar_t>(
            input_2d.data_ptr<scalar_t>(),
            d_output_2d.data_ptr<scalar_t>(),
            d_weight.data_ptr<float>(),
            in_dim,
            hidden_dim,
            out_dim);
    );
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("wgrad_gemm_accum_fp32", &wgrad_gemm_accum_fp32, "wgrad gemm accum in fp32");
}
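Semantically, the binding flattens all leading dimensions and accumulates the fp32 weight gradient in place. A hedged pure-PyTorch reference of the same math (it upcasts to fp32 before the matmul instead of using a mixed-precision GEMM, so results can differ in the last bits):

import torch

def wgrad_gemm_accum_fp32_ref(input, d_output, d_weight):
    """Reference: d_weight += d_output_2d^T @ input_2d, accumulated in fp32."""
    input_2d = input.reshape(-1, input.size(-1))           # [hidden_dim, in_dim]
    d_output_2d = d_output.reshape(-1, d_output.size(-1))  # [hidden_dim, out_dim]
    # in-place accumulation into the fp32 buffer, like the cuBLAS call with beta = 1
    d_weight.add_(d_output_2d.float().t() @ input_2d.float())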
megatron/fused_kernels/fused_weight_gradient_dense.cu (new file, mode 100644)

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <torch/torch.h>

/* Includes, cuda */
#include <cublas_v2.h>
#include <cuda_runtime.h>

// BF16 Tensor core wrapper around cublas GEMMEx
cublasStatus_t gemmex_wrapper(
    cublasHandle_t handle,
    cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k,
    const float *alpha,
    at::BFloat16 *A, int lda,
    at::BFloat16 *B, int ldb,
    const float *beta,
    float *C, int ldc) {
  return cublasGemmEx(
      handle, transa, transb, m, n, k,
      alpha,
      A, CUDA_R_16BF, lda,
      B, CUDA_R_16BF, ldb,
      beta,
      C, CUDA_R_32F, ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}

// FP16 Tensor core wrapper around cublas GEMMEx
cublasStatus_t gemmex_wrapper(
    cublasHandle_t handle,
    cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k,
    const float *alpha,
    at::Half *A, int lda,
    at::Half *B, int ldb,
    const float *beta,
    float *C, int ldc) {
  return cublasGemmEx(
      handle, transa, transb, m, n, k,
      alpha,
      A, CUDA_R_16F, lda,
      B, CUDA_R_16F, ldb,
      beta,
      C, CUDA_R_32F, ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}

// FP32 Tensor core wrapper around cublas GEMMEx
cublasStatus_t gemmex_wrapper(
    cublasHandle_t handle,
    cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k,
    const float *alpha,
    float *A, int lda,
    float *B, int ldb,
    const float *beta,
    float *C, int ldc) {
  return cublasGemmEx(
      handle, transa, transb, m, n, k,
      alpha,
      A, CUDA_R_32F, lda,
      B, CUDA_R_32F, ldb,
      beta,
      C, CUDA_R_32F, ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}

template <typename T>
int wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight,
                               int in_dim, int hidden_dim, int out_dim) {
    cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
    cudaStream_t stream;
    cublasGetStream(handle, &stream);
    const float alpha = 1.0;
    const float beta = 1.0;
    int status = 1;

    status = gemmex_wrapper(
        handle,
        CUBLAS_OP_N, CUBLAS_OP_T,
        in_dim, out_dim, hidden_dim,
        &alpha,
        input, in_dim,
        d_output, out_dim,
        &beta,
        d_weight, in_dim);
    return status;
}

template int wgrad_gemm_accum_fp32_cuda<at::Half>(
    at::Half *input, at::Half *d_output, float *d_weight,
    int in_dim, int hidden_dim, int out_dim);
template int wgrad_gemm_accum_fp32_cuda<at::BFloat16>(
    at::BFloat16 *input, at::BFloat16 *d_output, float *d_weight,
    int in_dim, int hidden_dim, int out_dim);
template int wgrad_gemm_accum_fp32_cuda<float>(
    float *input, float *d_output, float *d_weight,
    int in_dim, int hidden_dim, int out_dim);
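cuBLAS is column-major, so a row-major [hidden_dim, in_dim] tensor is read as an in_dim x hidden_dim matrix; with op(A)=N, op(B)=T, m=in_dim, n=out_dim, k=hidden_dim and alpha=beta=1, the call above computes d_weight += d_output^T @ input. A hedged smoke test of that claim (assumes a CUDA device and the built fused_dense_cuda extension; tolerances are loose because the GEMM multiplies in half precision):

import torch
import fused_dense_cuda  # built via megatron.fused_kernels.load(args)

hidden_dim, in_dim, out_dim = 128, 64, 32
inp = torch.randn(hidden_dim, in_dim, dtype=torch.half, device='cuda')
d_out = torch.randn(hidden_dim, out_dim, dtype=torch.half, device='cuda')
d_w = torch.zeros(out_dim, in_dim, dtype=torch.float, device='cuda')

fused_dense_cuda.wgrad_gemm_accum_fp32(inp, d_out, d_w)
ref = d_out.float().t() @ inp.float()
torch.testing.assert_close(d_w, ref, rtol=1e-2, atol=1e-2)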
megatron/fused_kernels/type_shim.h

@@ -39,6 +39,32 @@
}

#define DISPATCH_HALF_BFLOAT_AND_FLOAT(TYPE, NAME, ...) \
  switch(TYPE) \
  { \
    case at::ScalarType::Half: \
    { \
      using scalar_t = at::Half; \
      __VA_ARGS__; \
      break; \
    } \
    case at::ScalarType::BFloat16: \
    { \
      using scalar_t = at::BFloat16; \
      __VA_ARGS__; \
      break; \
    } \
    case at::ScalarType::Float: \
    { \
      using scalar_t = float; \
      __VA_ARGS__; \
      break; \
    } \
    default: \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
  }

#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
  switch(TYPEIN) \
megatron/model/distributed.py

@@ -164,13 +164,13 @@ class DistributedDataParallel(DistributedDataParallelBase):
                    grad_acc.register_hook(self._make_param_hook(param))
                    self.grad_accs.append(grad_acc)

    def _make_param_hook(self, param):
        """Create the all-reduce hook for backprop."""
        # Hook used for back-prop.
        def param_hook(*unused):
            # Add the gradient to the buffer.
-            if param.grad.data is not None:
+            if param.grad is not None:
+                # The gradient function of linear layers is fused with GEMMs
                param.main_grad.add_(param.grad.data)
                # Now we can deallocate grad memory.
                param.grad = None
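The guard changed from `param.grad.data is not None` to `param.grad is not None` because with gradient accumulation fusion the linear layers write straight into `main_grad` and leave `param.grad` unset, so dereferencing `.data` would fail. A minimal standalone sketch of the accumulator-hook pattern the collapsed context uses (illustrative, not the full DDP class):

import torch

param = torch.nn.Parameter(torch.randn(4, 4))
param.main_grad = torch.zeros_like(param.data, dtype=torch.float)

# expand_as records a graph node whose next_functions holds the
# AccumulateGrad node of this leaf parameter
grad_acc = param.expand_as(param).grad_fn.next_functions[0][0]

def param_hook(*unused):
    if param.grad is not None:             # fused layers may leave grad as None
        param.main_grad.add_(param.grad.data)
        param.grad = None                  # free the low-precision gradient now

grad_acc.register_hook(param_hook)

param.sum().backward()
print(param.main_grad.sum())               # tensor(16.), accumulated in fp32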
megatron/model/language_model.py

@@ -31,7 +31,6 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
                       bias=None):
    """LM logits using word embedding weights."""
    args = get_args()
    # Parallel logits.
    if args.async_tensor_model_parallel_allreduce or\
       args.model_parallel_memory_opt:
megatron/model/transformer.py

@@ -116,6 +116,53 @@ class ParallelMLP(MegatronModule):
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
        return output, output_bias


class SwitchMLP(MegatronModule):
    """
    Routes input to one of N MLP "experts"
    """
    def __init__(self, init_method, output_layer_init_method):
        super(SwitchMLP, self).__init__()
        args = get_args()
        self.router = torch.nn.Linear(args.hidden_size, args.num_experts)
        self.experts = torch.nn.ModuleList()
        for i in range(args.num_experts):
            self.experts.append(ParallelMLP(init_method, output_layer_init_method))

    def forward(self, hidden_states):
        # hidden_states: [b, s, h]
        b = hidden_states.size(0)
        s = hidden_states.size(1)
        h = hidden_states.size(2)
        route = self.router(hidden_states)
        route = torch.nn.functional.softmax(route, dim=2)
        max_prob, max_ind = torch.max(route, dim=2)
        max_prob = torch.unsqueeze(max_prob, 2)  # [b s 1]

        # TODO (rprenger): this could be made easier to read
        # Converting [b, s, h] to [b*s, h].
        # Each vector could be routed differently
        hidden_states = hidden_states.view(-1, hidden_states.size(2))  # [b*s h]
        max_prob = max_prob.view(-1, max_prob.size(2))  # [b*s 1]
        max_ind = max_ind.view(-1)  # [b*s]

        output_total = torch.empty_like(hidden_states)
        output_bias_total = torch.empty_like(hidden_states)
        # TODO (rprenger): This does each expert in serial, but it could be parallelized
        for expert_num, expert in enumerate(self.experts):
            local_indices = (max_ind == expert_num).nonzero()
            hidden = hidden_states[local_indices, :]
            output, output_bias = expert(hidden)
            output_bias = output_bias.expand_as(output)
            output_total[local_indices, :] = output
            output_bias_total[local_indices, :] = output_bias

        output_total = output_total * max_prob
        output_bias_total = output_bias_total * max_prob
        output_total = output_total.view(b, s, h)
        output_bias_total = output_bias_total.view(b, s, h)

        return output_total, output_bias_total


class ParallelAttention(MegatronModule):
    """Parallel self-attention layer abstract class.
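A hedged shape walk-through of the top-1 routing above, with illustrative standalone tensors. Note that `nonzero()` returns an [n, 1] index tensor, so each per-expert slice keeps a singleton middle dimension; the expert outputs are written back through the same indices, and the final multiply by `max_prob` broadcasts over the hidden dimension:

import torch

b, s, h, num_experts = 2, 3, 8, 4
hidden_states = torch.randn(b, s, h)
router = torch.nn.Linear(h, num_experts)          # stand-in for self.router

route = torch.nn.functional.softmax(router(hidden_states), dim=2)
max_prob, max_ind = torch.max(route, dim=2)       # both [b, s]
max_prob = torch.unsqueeze(max_prob, 2)           # [b, s, 1]

flat = hidden_states.view(-1, h)                  # [b*s, h]
max_ind = max_ind.view(-1)                        # [b*s]
local_indices = (max_ind == 0).nonzero()          # [n, 1]: rows routed to expert 0
print(flat[local_indices, :].shape)               # torch.Size([n, 1, h])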
@@ -482,8 +529,10 @@ class ParallelTransformerLayer(MegatronModule):
            sequence_parallel=args.model_parallel_memory_opt)

        # MLP
-        self.mlp = ParallelMLP(init_method, output_layer_init_method)
+        if args.num_experts is not None:
+            self.mlp = SwitchMLP(init_method, output_layer_init_method)
+        else:
+            self.mlp = ParallelMLP(init_method, output_layer_init_method)

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
megatron/mpu/__init__.py

@@ -49,6 +49,7 @@ from .initialize import get_virtual_pipeline_model_parallel_rank, set_virtual_pi
from .initialize import initialize_model_parallel
from .initialize import model_parallel_is_initialized
from .layers import LinearWithGradAccumulationAndAsyncAllreduce
from .layers import ColumnParallelLinear
from .layers import RowParallelLinear
from .layers import VocabParallelEmbedding
megatron/mpu/layers.py

@@ -237,6 +237,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
    @staticmethod
    def backward(ctx, grad_output):
        import fused_dense_cuda
        input, weight = ctx.saved_tensors
        use_bias = ctx.use_bias
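The collapsed context between this hunk and the next presumably branches on `ctx.gradient_accumulation_fusion`, calling the new extension instead of the plain matmul shown further down; a hedged sketch of that branch (2-D `total_input`/`grad_output` views assumed, as in the surrounding code):

# inside backward(), once total_input and grad_output are available
if ctx.gradient_accumulation_fusion:
    # beta = 1 GEMM: accumulate straight into the fp32 main_grad buffer
    fused_dense_cuda.wgrad_gemm_accum_fp32(total_input, grad_output,
                                           weight.main_grad)
    grad_weight = None  # nothing extra for the DDP hook to fold in
else:
    grad_weight = grad_output.t().matmul(total_input)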
@@ -280,12 +281,10 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
            sub_grad_input = torch.empty(dim_size, dtype=input.dtype,
                                         device=torch.cuda.current_device(),
                                         requires_grad=False)
            # reduce_scatter
            handle = torch.distributed._reduce_scatter_base(
                sub_grad_input, grad_input,
                group=get_tensor_model_parallel_group(), async_op=True)
            # Delay the start of weight gradient computation shortly (3us) to have
            # reduce scatter scheduled first and have GPU resources allocated
            _ = torch.empty(1, device=grad_output.device) + 1
@@ -298,14 +297,15 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
        grad_weight = grad_output.t().matmul(total_input)
        grad_bias = grad_output.sum(dim=0) if use_bias else None

-        if ctx.async_grad_allreduce:
-            handle.wait()
-        return grad_input, grad_weight, grad_bias
+        if ctx.model_parallel_memory_opt:
+            handle.wait()
+            return sub_grad_input, grad_weight, grad_bias
+
+        if ctx.async_grad_allreduce:
+            handle.wait()
+
+        return grad_input, grad_weight, grad_bias


class ColumnParallelLinear(torch.nn.Module):
    """Linear layer with column parallelism.
@@ -317,7 +317,7 @@ class ColumnParallelLinear(torch.nn.Module):
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias
-       gather_output: If true, call all-gather on output and make Y avaiable
+       gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        init_method: method to initialize weights. Note that bias is always set
@@ -382,13 +382,14 @@ class ColumnParallelLinear(torch.nn.Module):
        else:
            self.register_parameter('bias', None)

        self.async_tensor_model_parallel_allreduce = (
-            not args.no_async_tensor_model_parallel_allreduce and
+            args.async_tensor_model_parallel_allreduce and
            world_size > 1)
        self.model_parallel_memory_opt = (
            args.model_parallel_memory_opt and
            world_size > 1)
        assert not self.async_tensor_model_parallel_allreduce or \
            not self.model_parallel_memory_opt
        self.gradient_accumulation_fusion = args.gradient_accumulation_fusion

    def forward(self, input_):
@@ -491,8 +492,8 @@ class RowParallelLinear(torch.nn.Module):
                self.bias.zero_()
        else:
            self.register_parameter('bias', None)
        self.model_parallel_memory_opt = args.model_parallel_memory_opt
        self.gradient_accumulation_fusion = args.gradient_accumulation_fusion

    def forward(self, input_):
@@ -503,7 +504,9 @@ class RowParallelLinear(torch.nn.Module):
            assert not self.model_parallel_memory_opt
            input_parallel = scatter_to_tensor_model_parallel_region(input_)
        # Matrix multiply.
-        output_parallel = F.linear(input_parallel, self.weight)
+        output_parallel = LinearWithGradAccumulationAndAsyncAllreduce.apply(
+            input_parallel, self.weight, None,
+            self.gradient_accumulation_fusion, None)
        # All-reduce across all the partitions.
        if self.model_parallel_memory_opt:
            output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)