OpenDAS / Megatron-LM / Commits / 6fdbf26b

Commit 6fdbf26b, authored Mar 07, 2022 by Vijay Korthikanti

Merge branch 'main' into allreduce-split

Parents: 0d77c0e9, 53f3efc4

Showing 10 changed files with 325 additions and 20 deletions (+325 -20)
megatron/arguments.py                                     +18  -5
megatron/fused_kernels/__init__.py                        +10  -0
megatron/fused_kernels/fused_weight_gradient_dense.cpp    +47  -0
megatron/fused_kernels/fused_weight_gradient_dense.cu     +157 -0
megatron/fused_kernels/type_shim.h                        +26  -0
megatron/model/distributed.py                             +2   -2
megatron/model/language_model.py                          +0   -1
megatron/model/transformer.py                             +51  -2
megatron/mpu/__init__.py                                  +1   -0
megatron/mpu/layers.py                                    +13  -10
megatron/arguments.py

```diff
@@ -172,6 +172,14 @@ def parse_args(extra_args_provider=None, defaults={},
     if args.accumulate_allreduce_grads_in_fp32:
         assert args.DDP_impl == 'local'
         assert args.use_contiguous_buffers_in_local_ddp
+    else:
+        if args.gradient_accumulation_fusion:
+            args.gradient_accumulation_fusion = False
+            if args.rank == 0:
+                print('Gradient accumulation fusion to linear layer weight '
+                      'gradient computation is supported only with fp32 '
+                      'gradient accumulation. Setting gradient_accumulation_fusion '
+                      'to False', flush=True)

     # For torch DDP, we do not use contiguous buffer
     if args.DDP_impl == 'torch':

@@ -357,7 +365,8 @@ def _add_network_size_args(parser):
     group.add_argument('--bert-no-binary-head', action='store_false',
                        help='Disable BERT binary head.',
                        dest='bert_binary_head')
+    group.add_argument('--num-experts', type=int, default=None,
+                       help='Number of Experts in Switch Transformer (None means no Switch)')
     return parser

@@ -521,10 +530,11 @@ def _add_training_args(parser):
                        choices=['single', 'cyclic'],
                        help='Single pass vs multiple pass data loader')
     group.add_argument('--no-async-tensor-model-parallel-allreduce',
-                       action='store_true',
+                       action='store_false',
                        help='Disable asynchronous execution of '
                        'tensor-model-parallel all-reduce with weight '
-                       'gradient compuation of a column-linear layer.')
+                       'gradient compuation of a column-linear layer.',
+                       dest='async_tensor_model_parallel_allreduce')
     group.add_argument('--no-persist-layer-norm', action='store_true',
                        help='Disable using persistent fused layer norm kernel. '
                        'This kernel supports only a set of hidden sizes. Please '

@@ -532,8 +542,11 @@ def _add_training_args(parser):
                        'size is supported.')
     group.add_argument('--model-parallel-memory-opt', action='store_true',
                        help='Enable model parallel memory optmization.')
+    group.add_argument('--no-gradient-accumulation-fusion',
+                       action='store_false',
+                       help='Disable fusing gradient accumulation to weight '
+                       'gradient computation of linear layers',
+                       dest='gradient_accumulation_fusion')
     return parser
```
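For readers unfamiliar with the `--no-X` / `dest` idiom used by the two new switches, here is a minimal, self-contained argparse sketch (not part of the commit) showing why the feature defaults to enabled and how `--no-gradient-accumulation-fusion` turns it off:

```python
import argparse

# Minimal sketch of the "--no-X stores False into args.X" pattern used above.
parser = argparse.ArgumentParser()
parser.add_argument('--no-gradient-accumulation-fusion',
                    action='store_false',
                    help='Disable fusing gradient accumulation to weight '
                         'gradient computation of linear layers',
                    dest='gradient_accumulation_fusion')

args = parser.parse_args([])
assert args.gradient_accumulation_fusion is True   # enabled by default

args = parser.parse_args(['--no-gradient-accumulation-fusion'])
assert args.gradient_accumulation_fusion is False  # explicitly disabled
```

The same pattern is applied to `--no-async-tensor-model-parallel-allreduce`, which now stores into `args.async_tensor_model_parallel_allreduce` instead of a `no_`-prefixed attribute.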
megatron/fused_kernels/__init__.py

```diff
@@ -94,6 +94,16 @@ def load(args):
         fused_mix_prec_layer_norm_cuda = _cpp_extention_load_helper(
             "fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags)

+    # =================================
+    # Fused gradient accumulation to weight gradient computation of linear layer
+    # =================================
+
+    if args.gradient_accumulation_fusion:
+        sources = [srcpath / 'fused_weight_gradient_dense.cpp',
+                   srcpath / 'fused_weight_gradient_dense.cu']
+        fused_dense_cuda = _cpp_extention_load_helper(
+            "fused_dense_cuda", sources, [])
+

 def _get_cuda_bare_metal_version(cuda_dir):
     raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
```
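`_cpp_extention_load_helper` itself is not part of this hunk; judging from the other fused kernels in this file it presumably wraps PyTorch's JIT extension loader. A rough sketch of what building the two new sources with `torch.utils.cpp_extension.load` could look like (the wrapper name, paths, and build directory here are assumptions, not the repository's actual helper):

```python
from torch.utils.cpp_extension import load

def load_fused_dense_cuda(srcpath, buildpath):
    # Sketch: JIT-compile the new weight-gradient kernel only when
    # --no-gradient-accumulation-fusion has NOT been passed.
    sources = [srcpath / 'fused_weight_gradient_dense.cpp',
               srcpath / 'fused_weight_gradient_dense.cu']
    return load(name='fused_dense_cuda',
                sources=[str(s) for s in sources],
                build_directory=str(buildpath),
                verbose=False)
```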
megatron/fused_kernels/fused_weight_gradient_dense.cpp (new file, mode 100644)

```cpp
#include <torch/torch.h>
#include <torch/extension.h>
#include <vector>
#include <stdio.h>

#include "type_shim.h"

template <typename T>
int wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight,
                               int in_dim, int hidden_dim, int out_dim);

void wgrad_gemm_accum_fp32(const at::Tensor input, const at::Tensor d_output,
                           at::Tensor d_weight) {
    at::Tensor input_2d, d_output_2d;
    // input tensor: collapse to the first dim
    auto in_sizes = input.sizes();
    if (input.dim() > 2) {
        input_2d = input.view({-1, in_sizes[in_sizes.size() - 1]});
    } else {
        input_2d = input;
    }
    // d_output tensor: collapse to the first dim
    auto d_out_sizes = d_output.sizes();
    if (d_output.dim() > 2) {
        d_output_2d = d_output.view({-1, d_out_sizes[d_out_sizes.size() - 1]});
    } else {
        d_output_2d = d_output;
    }

    int hidden_dim = input_2d.size(0);
    int in_dim = input_2d.size(1);
    int out_dim = d_weight.size(0);

    DISPATCH_HALF_BFLOAT_AND_FLOAT(input_2d.scalar_type(), "wgrad_gemm_accum_fp32",
        int result = wgrad_gemm_accum_fp32_cuda<scalar_t>(
            input_2d.data_ptr<scalar_t>(),
            d_output_2d.data_ptr<scalar_t>(),
            d_weight.data_ptr<float>(),
            in_dim,
            hidden_dim,
            out_dim);
    );
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("wgrad_gemm_accum_fp32", &wgrad_gemm_accum_fp32, "wgrad gemm accum in fp32");
}
```
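From the shapes above (`hidden_dim` is the number of flattened rows, `in_dim` the input features, `out_dim` the rows of `d_weight`), the op accumulates a linear layer's weight gradient into an fp32 buffer. A plain PyTorch sketch of the intended math, useful as a debugging reference (this is my reading of the kernel, not code from the commit):

```python
import torch

def wgrad_gemm_accum_fp32_reference(input, d_output, d_weight_fp32):
    """Reference semantics: d_weight_fp32 += d_output^T @ input, in fp32.

    input:         [..., in_dim]     fp16/bf16/fp32 activations
    d_output:      [..., out_dim]    same dtype as input
    d_weight_fp32: [out_dim, in_dim] fp32 accumulation buffer (main_grad)
    """
    input_2d = input.reshape(-1, input.shape[-1])
    d_output_2d = d_output.reshape(-1, d_output.shape[-1])
    d_weight_fp32.add_(d_output_2d.t().float() @ input_2d.float())
    return d_weight_fp32
```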
megatron/fused_kernels/fused_weight_gradient_dense.cu (new file, mode 100644)

```cpp
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <torch/torch.h>

/* Includes, cuda */
#include <cublas_v2.h>
#include <cuda_runtime.h>

// BF16 Tensor core wrapper around cublas GEMMEx
cublasStatus_t gemmex_wrapper(
    cublasHandle_t handle,
    cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k,
    const float *alpha,
    at::BFloat16 *A, int lda,
    at::BFloat16 *B, int ldb,
    const float *beta,
    float *C, int ldc) {
  return cublasGemmEx(
      handle, transa, transb, m, n, k,
      alpha,
      A, CUDA_R_16BF, lda,
      B, CUDA_R_16BF, ldb,
      beta,
      C, CUDA_R_32F, ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}

// FP16 Tensor core wrapper around cublas GEMMEx
cublasStatus_t gemmex_wrapper(
    cublasHandle_t handle,
    cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k,
    const float *alpha,
    at::Half *A, int lda,
    at::Half *B, int ldb,
    const float *beta,
    float *C, int ldc) {
  return cublasGemmEx(
      handle, transa, transb, m, n, k,
      alpha,
      A, CUDA_R_16F, lda,
      B, CUDA_R_16F, ldb,
      beta,
      C, CUDA_R_32F, ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}

// FP32 Tensor core wrapper around cublas GEMMEx
cublasStatus_t gemmex_wrapper(
    cublasHandle_t handle,
    cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k,
    const float *alpha,
    float *A, int lda,
    float *B, int ldb,
    const float *beta,
    float *C, int ldc) {
  return cublasGemmEx(
      handle, transa, transb, m, n, k,
      alpha,
      A, CUDA_R_32F, lda,
      B, CUDA_R_32F, ldb,
      beta,
      C, CUDA_R_32F, ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}

template <typename T>
int wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight,
                               int in_dim, int hidden_dim, int out_dim) {
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream;
  cublasGetStream(handle, &stream);
  const float alpha = 1.0;
  const float beta = 1.0;
  int status = 1;

  status = gemmex_wrapper(
      handle,
      CUBLAS_OP_N, CUBLAS_OP_T,
      in_dim, out_dim, hidden_dim,
      &alpha,
      input, in_dim,
      d_output, out_dim,
      &beta,
      d_weight, in_dim);
  return status;
}

template int wgrad_gemm_accum_fp32_cuda<at::Half>(
    at::Half *input, at::Half *d_output, float *d_weight,
    int in_dim, int hidden_dim, int out_dim);
template int wgrad_gemm_accum_fp32_cuda<at::BFloat16>(
    at::BFloat16 *input, at::BFloat16 *d_output, float *d_weight,
    int in_dim, int hidden_dim, int out_dim);
template int wgrad_gemm_accum_fp32_cuda<float>(
    float *input, float *d_output, float *d_weight,
    int in_dim, int hidden_dim, int out_dim);
```
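Note that all three overloads keep the compute and output types at `CUDA_R_32F` and pass `beta = 1.0`, so each call adds a partial weight gradient into the existing fp32 buffer rather than overwriting it. A small CPU-side sketch (illustrative sizes and values, not from the commit) of why accumulating in fp32 rather than fp16 matters across many micro-batches:

```python
import torch

torch.manual_seed(0)
steps, out_dim, in_dim, rows = 64, 128, 128, 256

acc_fp16 = torch.zeros(out_dim, in_dim, dtype=torch.float16)
acc_fp32 = torch.zeros(out_dim, in_dim, dtype=torch.float32)

for _ in range(steps):
    x = (torch.randn(rows, in_dim) * 1e-2).half()
    dy = (torch.randn(rows, out_dim) * 1e-2).half()
    dw = (dy.float().t() @ x.float()).half()   # fp16 partial weight gradient
    acc_fp16 += dw                             # fp16 accumulation rounds every step
    acc_fp32 += dw.float()                     # fp32 accumulation, as the kernel does

# The gap between the two accumulators grows with the number of steps.
print((acc_fp16.float() - acc_fp32).abs().max())
```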
megatron/fused_kernels/type_shim.h

```diff
@@ -39,6 +39,32 @@
   }


+#define DISPATCH_HALF_BFLOAT_AND_FLOAT(TYPE, NAME, ...)                      \
+  switch(TYPE)                                                               \
+  {                                                                          \
+    case at::ScalarType::Half:                                               \
+    {                                                                        \
+      using scalar_t = at::Half;                                             \
+      __VA_ARGS__;                                                           \
+      break;                                                                 \
+    }                                                                        \
+    case at::ScalarType::BFloat16:                                           \
+    {                                                                        \
+      using scalar_t = at::BFloat16;                                         \
+      __VA_ARGS__;                                                           \
+      break;                                                                 \
+    }                                                                        \
+    case at::ScalarType::Float:                                              \
+    {                                                                        \
+      using scalar_t = float;                                                \
+      __VA_ARGS__;                                                           \
+      break;                                                                 \
+    }                                                                        \
+    default:                                                                 \
+      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");        \
+  }
+
 #define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
 switch(TYPEIN)                                                               \
```
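The new macro only admits `Half`, `BFloat16`, and `Float` inputs, while the weight-gradient buffer stays fp32 regardless. A Python-side analogue of that guard, for illustration only (the wrapper name is hypothetical; `fused_dense_cuda.wgrad_gemm_accum_fp32` is the binding defined in the new .cpp file):

```python
import torch

_SUPPORTED = (torch.float16, torch.bfloat16, torch.float32)

def wgrad_gemm_accum_fp32_checked(fused_dense_cuda, input, d_output, main_grad):
    # Mirrors DISPATCH_HALF_BFLOAT_AND_FLOAT: half/bf16/fp32 activations in,
    # accumulation always into an fp32 buffer.
    if input.dtype not in _SUPPORTED:
        raise TypeError(f"wgrad_gemm_accum_fp32 not implemented for '{input.dtype}'")
    assert d_output.dtype == input.dtype, "input and d_output must share a dtype"
    assert main_grad.dtype == torch.float32, "accumulation buffer must be fp32"
    fused_dense_cuda.wgrad_gemm_accum_fp32(input, d_output, main_grad)
```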
megatron/model/distributed.py

```diff
@@ -164,13 +164,13 @@ class DistributedDataParallel(DistributedDataParallelBase):
                 grad_acc.register_hook(self._make_param_hook(param))
                 self.grad_accs.append(grad_acc)


     def _make_param_hook(self, param):
         """Create the all-reduce hook for backprop."""
         # Hook used for back-prop.
         def param_hook(*unused):
             # Add the gradient to the buffer.
-            if param.grad.data is not None:
+            if param.grad is not None:
+                # The gradient function of linear layers is fused with GEMMs
                 param.main_grad.add_(param.grad.data)
                 # Now we can deallocate grad memory.
                 param.grad = None
```
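The changed guard matters because, with the new fusion, a linear layer's weight may never receive a `.grad` at all: the fused kernel accumulates straight into `param.main_grad`, so the hook must tolerate `param.grad is None`. A condensed single-GPU sketch of the hook pattern (the AccumulateGrad lookup mirrors the surrounding class; the buffer setup is simplified and not the full Megatron implementation):

```python
import torch

def make_param_hook(param):
    def param_hook(*unused):
        # With gradient-accumulation fusion the fused GEMM may have already
        # written into main_grad, leaving param.grad as None.
        if param.grad is not None:
            param.main_grad.add_(param.grad.data)
            param.grad = None  # free the temporary gradient right away
    return param_hook

# Register on the AccumulateGrad node, as DistributedDataParallel does.
param = torch.nn.Parameter(torch.randn(4, 4))
param.main_grad = torch.zeros_like(param, dtype=torch.float32)
param_tmp = param.expand_as(param)
grad_acc = param_tmp.grad_fn.next_functions[0][0]
grad_acc.register_hook(make_param_hook(param))

loss = (param * 2).sum()
loss.backward()
print(param.main_grad.abs().sum())  # gradients landed in the fp32 buffer
```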
megatron/model/language_model.py

```diff
@@ -31,7 +31,6 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
                        bias=None):
     """LM logits using word embedding weights."""
     args = get_args()
-
     # Parallel logits.
     if args.async_tensor_model_parallel_allreduce or \
        args.model_parallel_memory_opt:
```
megatron/model/transformer.py

```diff
@@ -116,6 +116,53 @@ class ParallelMLP(MegatronModule):
         output, output_bias = self.dense_4h_to_h(intermediate_parallel)
         return output, output_bias

+class SwitchMLP(MegatronModule):
+    """
+    Routes input to one of N MLP "experts"
+    """
+    def __init__(self, init_method, output_layer_init_method):
+        super(SwitchMLP, self).__init__()
+        args = get_args()
+        self.router = torch.nn.Linear(args.hidden_size, args.num_experts)
+        self.experts = torch.nn.ModuleList()
+        for i in range(args.num_experts):
+            self.experts.append(ParallelMLP(init_method, output_layer_init_method))
+
+    def forward(self, hidden_states):
+        # hidden_states: [b, s, h]
+        b = hidden_states.size(0)
+        s = hidden_states.size(1)
+        h = hidden_states.size(2)
+        route = self.router(hidden_states)
+        route = torch.nn.functional.softmax(route, dim=2)
+        max_prob, max_ind = torch.max(route, dim=2)
+        max_prob = torch.unsqueeze(max_prob, 2)  # [b s 1]
+
+        # TODO (rprenger) TODO this could be made easier to read
+        # Converting [b, s, h] to [b*s, h].
+        # Each vector could be routed differently
+        hidden_states = hidden_states.view(-1, hidden_states.size(2))  # [b*s h]
+        max_prob = max_prob.view(-1, max_prob.size(2))  # [b*s 1]
+        max_ind = max_ind.view(-1)  # [b*s]
+
+        output_total = torch.empty_like(hidden_states)
+        output_bias_total = torch.empty_like(hidden_states)
+        # TODO (rprenger) This does each expert in serial, but it could be parallelized
+        for expert_num, expert in enumerate(self.experts):
+            local_indices = (max_ind == expert_num).nonzero()
+            hidden = hidden_states[local_indices, :]
+            output, output_bias = expert(hidden)
+            output_bias = output_bias.expand_as(output)
+            output_total[local_indices, :] = output
+            output_bias_total[local_indices, :] = output_bias
+
+        output_total = output_total * max_prob
+        output_bias_total = output_bias_total * max_prob
+        output_total = output_total.view(b, s, h)
+        output_bias_total = output_bias_total.view(b, s, h)
+
+        return output_total, output_bias_total
+
+
 class ParallelAttention(MegatronModule):
     """Parallel self-attention layer abstract class.

@@ -482,8 +529,10 @@ class ParallelTransformerLayer(MegatronModule):
             sequence_parallel=args.model_parallel_memory_opt)

         # MLP
-        self.mlp = ParallelMLP(init_method,
-                               output_layer_init_method)
+        if args.num_experts is not None:
+            self.mlp = SwitchMLP(init_method, output_layer_init_method)
+        else:
+            self.mlp = ParallelMLP(init_method, output_layer_init_method)

     def forward(self, hidden_states, attention_mask,
                 encoder_output=None, enc_dec_attn_mask=None,
```
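The routing in `SwitchMLP.forward` is top-1: each token is sent to exactly one expert and the expert output is rescaled by the winning router probability. A stripped-down sketch of just that routing math, with plain linear layers standing in for `ParallelMLP` (shapes follow the comments in the diff; this is an illustration, not the class above):

```python
import torch
import torch.nn.functional as F

b, s, h, num_experts = 2, 4, 8, 3
hidden_states = torch.randn(b, s, h)

router = torch.nn.Linear(h, num_experts)
experts = torch.nn.ModuleList(torch.nn.Linear(h, h) for _ in range(num_experts))

route = F.softmax(router(hidden_states), dim=2)       # [b, s, num_experts]
max_prob, max_ind = torch.max(route, dim=2)           # [b, s]
flat = hidden_states.view(-1, h)                      # [b*s, h]
max_prob = max_prob.view(-1, 1)                       # [b*s, 1]
max_ind = max_ind.view(-1)                            # [b*s]

output = torch.empty_like(flat)
for expert_num, expert in enumerate(experts):
    idx = (max_ind == expert_num).nonzero(as_tuple=True)[0]
    output[idx] = expert(flat[idx])                   # each token visits one expert

output = (output * max_prob).view(b, s, h)            # rescale by router probability
```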
megatron/mpu/__init__.py

```diff
@@ -49,6 +49,7 @@ from .initialize import get_virtual_pipeline_model_parallel_rank, set_virtual_pi
 from .initialize import initialize_model_parallel
 from .initialize import model_parallel_is_initialized

+from .layers import LinearWithGradAccumulationAndAsyncAllreduce
 from .layers import ColumnParallelLinear
 from .layers import RowParallelLinear
 from .layers import VocabParallelEmbedding
```
megatron/mpu/layers.py

```diff
@@ -237,6 +237,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
+        import fused_dense_cuda
         input, weight = ctx.saved_tensors
         use_bias = ctx.use_bias

@@ -280,12 +281,10 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
             sub_grad_input = torch.empty(dim_size, dtype=input.dtype,
                                          device=torch.cuda.current_device(),
                                          requires_grad=False)
             # reduce_scatter
             handle = torch.distributed._reduce_scatter_base(sub_grad_input, grad_input,
                                                             group=get_tensor_model_parallel_group(),
                                                             async_op=True)
             # Delay the start of weight gradient computation shortly (3us) to have
             # reduce scatter scheduled first and have GPU resources allocated
             _ = torch.empty(1, device=grad_output.device) + 1

@@ -298,14 +297,15 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
         grad_weight = grad_output.t().matmul(total_input)
         grad_bias = grad_output.sum(dim=0) if use_bias else None

-        if ctx.async_grad_allreducei:
-            handle.wait()
-        return grad_input, grad_weight, grad_bias
-
         if ctx.model_parallel_memory_opt:
             handle.wait()
             return sub_grad_input, grad_weight, grad_bias

+        if ctx.async_grad_allreduce:
+            handle.wait()
+
+        return grad_input, grad_weight, grad_bias
+

 class ColumnParallelLinear(torch.nn.Module):
     """Linear layer with column parallelism.

@@ -317,7 +317,7 @@ class ColumnParallelLinear(torch.nn.Module):
         input_size: first dimension of matrix A.
         output_size: second dimension of matrix A.
         bias: If true, add bias
-        gather_output: If true, call all-gather on output and make Y avaiable
+        gather_output: If true, call all-gather on output and make Y available
                        to all GPUs, otherwise, every GPU will have its output
                        which is Y_i = XA_i
         init_method: method to initialize weights. Note that bias is always set

@@ -382,13 +382,14 @@ class ColumnParallelLinear(torch.nn.Module):
         else:
             self.register_parameter('bias', None)

         self.async_tensor_model_parallel_allreduce = (
-            not args.no_async_tensor_model_parallel_allreduce and
+            args.async_tensor_model_parallel_allreduce and
             world_size > 1)
         self.model_parallel_memory_opt = (
             args.model_parallel_memory_opt and
             world_size > 1)
         assert not self.async_tensor_model_parallel_allreduce or \
             not self.model_parallel_memory_opt
+        self.gradient_accumulation_fusion = args.gradient_accumulation_fusion

     def forward(self, input_):

@@ -491,8 +492,8 @@ class RowParallelLinear(torch.nn.Module):
                 self.bias.zero_()
         else:
             self.register_parameter('bias', None)
         self.model_parallel_memory_opt = args.model_parallel_memory_opt
+        self.gradient_accumulation_fusion = args.gradient_accumulation_fusion

     def forward(self, input_):

@@ -503,7 +504,9 @@ class RowParallelLinear(torch.nn.Module):
             assert not self.model_parallel_memory_opt
             input_parallel = scatter_to_tensor_model_parallel_region(input_)
         # Matrix multiply.
-        output_parallel = F.linear(input_parallel, self.weight)
+        output_parallel = LinearWithGradAccumulationAndAsyncAllreduce.apply(
+            input_parallel, self.weight, None,
+            self.gradient_accumulation_fusion, None)
         # All-reduce across all the partitions.
         if self.model_parallel_memory_opt:
             output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
```
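`RowParallelLinear.forward` now routes through the same autograd function as the column-parallel path, so the weight gradient can be accumulated directly into `weight.main_grad` by the new CUDA op instead of materializing `weight.grad`. A minimal single-GPU sketch of that shape of autograd function (the real Megatron function takes extra flags for the async all-reduce and sequence-parallel paths and does distributed communication, all omitted here; `fused_dense_cuda` is the extension built above):

```python
import torch

class LinearWithGradAccumulation(torch.autograd.Function):
    """Sketch: forward is a plain linear; backward optionally fuses the
    weight-gradient GEMM with accumulation into weight.main_grad (fp32)."""

    @staticmethod
    def forward(ctx, input, weight, bias, gradient_accumulation_fusion):
        ctx.save_for_backward(input, weight)
        ctx.use_bias = bias is not None
        ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
        output = torch.matmul(input, weight.t())
        if bias is not None:
            output = output + bias
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        grad_input = grad_output.matmul(weight)
        grad_output_2d = grad_output.reshape(-1, grad_output.shape[-1])
        input_2d = input.reshape(-1, input.shape[-1])
        if ctx.gradient_accumulation_fusion:
            import fused_dense_cuda  # built by megatron/fused_kernels
            # Accumulate d_weight into the fp32 main_grad buffer in one call;
            # weight.main_grad is the buffer set up by Megatron's local DDP.
            fused_dense_cuda.wgrad_gemm_accum_fp32(
                input_2d, grad_output_2d, weight.main_grad)
            grad_weight = None  # the DDP hook now tolerates param.grad is None
        else:
            grad_weight = grad_output_2d.t().matmul(input_2d)
        grad_bias = grad_output_2d.sum(dim=0) if ctx.use_bias else None
        return grad_input, grad_weight, grad_bias, None
```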