OpenDAS / Megatron-LM / Commits / 9c5a830f

Commit 9c5a830f, authored Feb 22, 2022 by Jared Casper

Merge branch 'slym/grad_accum_fusion' into 'main'

Gradient accumulation fusion

See merge request ADLR/megatron-lm!394

Parents: 0ed2f6ac, b5726555

Showing 9 changed files with 311 additions and 38 deletions (+311, -38).
megatron/arguments.py                                      +16  -2
megatron/fused_kernels/__init__.py                         +10  -0
megatron/fused_kernels/fused_weight_gradient_dense.cpp     +47  -0
megatron/fused_kernels/fused_weight_gradient_dense.cu     +157  -0
megatron/fused_kernels/type_shim.h                         +26  -0
megatron/model/distributed.py                               +2  -2
megatron/model/language_model.py                           +11  -5
megatron/mpu/__init__.py                                    +1  -0
megatron/mpu/layers.py                                     +41 -29
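What the merge does, in one picture: during backprop of a linear layer the weight gradient is d_weight = d_output^T @ input, and with gradient accumulation it must be added into an fp32 main_grad buffer. The fused path performs that accumulation inside a single cublasGemmEx call with beta = 1, instead of materializing a separate grad_weight tensor and adding it afterwards. Below is a minimal plain-PyTorch sketch of the two mathematically equivalent computations; it is an illustration only, not the CUDA kernel added by this commit.

import torch

hidden, in_dim, out_dim = 32, 8, 4
inp = torch.randn(hidden, in_dim)
d_out = torch.randn(hidden, out_dim)
main_grad = torch.zeros(out_dim, in_dim)          # fp32 accumulation buffer

# Unfused: a separate wgrad GEMM plus an accumulation step.
grad_weight = d_out.t() @ inp
main_grad_unfused = main_grad + grad_weight

# Fused: one GEMM that accumulates into the existing buffer (beta = 1),
# which is what wgrad_gemm_accum_fp32 does on the GPU.
main_grad_fused = torch.addmm(main_grad, d_out.t(), inp)

assert torch.allclose(main_grad_unfused, main_grad_fused)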
megatron/arguments.py

@@ -172,6 +172,14 @@ def parse_args(extra_args_provider=None, defaults={},

     if args.accumulate_allreduce_grads_in_fp32:
         assert args.DDP_impl == 'local'
         assert args.use_contiguous_buffers_in_local_ddp
+    else:
+        if args.gradient_accumulation_fusion:
+            args.gradient_accumulation_fusion = False
+            if args.rank == 0:
+                print('Gradient accumulation fusion to linear layer weight '
+                      'gradient computation is supported only with fp32 '
+                      'gradient accumulation. Setting gradient_accumulation_fusion '
+                      'to False', flush=True)

     # For torch DDP, we do not use contiguous buffer
     if args.DDP_impl == 'torch':

@@ -521,15 +529,21 @@ def _add_training_args(parser):

                        choices=['single', 'cyclic'],
                        help='Single pass vs multiple pass data loader')
     group.add_argument('--no-async-tensor-model-parallel-allreduce',
-                       action='store_true',
+                       action='store_false',
                        help='Disable asynchronous execution of '
                        'tensor-model-parallel all-reduce with weight '
-                       'gradient compuation of a column-linear layer.')
+                       'gradient compuation of a column-linear layer.',
+                       dest='async_tensor_model_parallel_allreduce')
     group.add_argument('--no-persist-layer-norm', action='store_true',
                        help='Disable using persistent fused layer norm kernel. '
                        'This kernel supports only a set of hidden sizes. Please '
                        'check persist_ln_hidden_sizes if your hidden '
                        'size is supported.')
+    group.add_argument('--no-gradient-accumulation-fusion',
+                       action='store_false',
+                       help='Disable fusing gradient accumulation to weight '
+                       'gradient computation of linear layers',
+                       dest='gradient_accumulation_fusion')
     return parser
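For reference, the reworked flags rely on argparse's store_false action with an explicit dest, so both features default to enabled and the '--no-...' spelling turns them off. A small standalone sketch of that behavior, mirroring one of the new arguments (separate toy script, not repository code):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--no-gradient-accumulation-fusion',
                    action='store_false',
                    dest='gradient_accumulation_fusion',
                    help='Disable fusing gradient accumulation to weight '
                         'gradient computation of linear layers')

# store_false defaults the destination to True when the flag is absent.
print(parser.parse_args([]).gradient_accumulation_fusion)          # True
print(parser.parse_args(['--no-gradient-accumulation-fusion'])
      .gradient_accumulation_fusion)                               # False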
megatron/fused_kernels/__init__.py

@@ -94,6 +94,16 @@ def load(args):

     fused_mix_prec_layer_norm_cuda = _cpp_extention_load_helper(
         "fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags)

+    # =================================
+    # Fused gradient accumulation to weight gradient computation of linear layer
+    # =================================
+
+    if args.gradient_accumulation_fusion:
+        sources=[srcpath / 'fused_weight_gradient_dense.cpp',
+                 srcpath / 'fused_weight_gradient_dense.cu']
+        fused_dense_cuda = _cpp_extention_load_helper(
+            "fused_dense_cuda", sources, [])
+
 def _get_cuda_bare_metal_version(cuda_dir):
     raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
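The helper _cpp_extention_load_helper is not shown in this hunk; presumably it wraps PyTorch's JIT extension builder. A hedged sketch of the equivalent direct call with torch.utils.cpp_extension.load, with the source list taken from the hunk above and the build location assumed:

from pathlib import Path
from torch.utils.cpp_extension import load

srcpath = Path('megatron/fused_kernels')          # assumed source location
fused_dense_cuda = load(
    name='fused_dense_cuda',
    sources=[str(srcpath / 'fused_weight_gradient_dense.cpp'),
             str(srcpath / 'fused_weight_gradient_dense.cu')],
    extra_cuda_cflags=[],                         # no extra CUDA flags, as in the diff
    verbose=True)
# Afterwards fused_dense_cuda.wgrad_gemm_accum_fp32(...) is callable from Python.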
megatron/fused_kernels/fused_weight_gradient_dense.cpp (new file, mode 100644)

#include <torch/torch.h>
#include <torch/extension.h>
#include <vector>
#include <stdio.h>

#include "type_shim.h"

template <typename T>
int wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight,
                               int in_dim, int hidden_dim, int out_dim);

void wgrad_gemm_accum_fp32(const at::Tensor input, const at::Tensor d_output,
                           at::Tensor d_weight) {
    at::Tensor input_2d, d_output_2d;
    // input tensor: collapse to the first dim
    auto in_sizes = input.sizes();
    if (input.dim() > 2) {
        input_2d = input.view({-1, in_sizes[in_sizes.size() - 1]});
    } else {
        input_2d = input;
    }
    // d_output tensor: collapse to the first dim
    auto d_out_sizes = d_output.sizes();
    if (d_output.dim() > 2) {
        d_output_2d = d_output.view({-1, d_out_sizes[d_out_sizes.size() - 1]});
    } else {
        d_output_2d = d_output;
    }

    int hidden_dim = input_2d.size(0);
    int in_dim = input_2d.size(1);
    int out_dim = d_weight.size(0);

    DISPATCH_HALF_BFLOAT_AND_FLOAT(input_2d.scalar_type(), "wgrad_gemm_accum_fp32",
        int result = wgrad_gemm_accum_fp32_cuda<scalar_t>(
            input_2d.data_ptr<scalar_t>(),
            d_output_2d.data_ptr<scalar_t>(),
            d_weight.data_ptr<float>(),
            in_dim,
            hidden_dim,
            out_dim);
    );
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("wgrad_gemm_accum_fp32", &wgrad_gemm_accum_fp32,
          "wgrad gemm accum in fp32");
}
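A hedged usage sketch of the binding above, assuming the extension has been built as fused_dense_cuda and a CUDA device is available. The call accumulates the fp32 weight gradient of a linear layer directly into an existing buffer; the two-step reference should match up to floating-point accumulation order.

import torch
import fused_dense_cuda   # built from fused_weight_gradient_dense.{cpp,cu}

hidden, in_dim, out_dim = 512, 64, 128            # toy sizes
inp = torch.randn(hidden, in_dim, dtype=torch.half, device='cuda')
d_out = torch.randn(hidden, out_dim, dtype=torch.half, device='cuda')
main_grad = torch.zeros(out_dim, in_dim, dtype=torch.float, device='cuda')

# Accumulate d_out^T @ inp into main_grad, computed in fp32 on tensor cores.
fused_dense_cuda.wgrad_gemm_accum_fp32(inp, d_out, main_grad)

# Reference: the same accumulation done as two separate steps.
ref = d_out.t().float() @ inp.float()
print(torch.allclose(main_grad, ref, atol=1e-2, rtol=1e-3))   # small rounding differences expected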
megatron/fused_kernels/fused_weight_gradient_dense.cu (new file, mode 100644)

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <torch/torch.h>

/* Includes, cuda */
#include <cublas_v2.h>
#include <cuda_runtime.h>

// BF16 Tensor core wrapper around cublas GEMMEx
cublasStatus_t gemmex_wrapper(
    cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k,
    const float* alpha, at::BFloat16* A, int lda, at::BFloat16* B, int ldb,
    const float* beta, float* C, int ldc) {
  return cublasGemmEx(
      handle, transa, transb, m, n, k,
      alpha, A, CUDA_R_16BF, lda, B, CUDA_R_16BF, ldb,
      beta, C, CUDA_R_32F, ldc,
      CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}

// FP16 Tensor core wrapper around cublas GEMMEx
cublasStatus_t gemmex_wrapper(
    cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k,
    const float* alpha, at::Half* A, int lda, at::Half* B, int ldb,
    const float* beta, float* C, int ldc) {
  return cublasGemmEx(
      handle, transa, transb, m, n, k,
      alpha, A, CUDA_R_16F, lda, B, CUDA_R_16F, ldb,
      beta, C, CUDA_R_32F, ldc,
      CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}

// FP32 Tensor core wrapper around cublas GEMMEx
cublasStatus_t gemmex_wrapper(
    cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k,
    const float* alpha, float* A, int lda, float* B, int ldb,
    const float* beta, float* C, int ldc) {
  return cublasGemmEx(
      handle, transa, transb, m, n, k,
      alpha, A, CUDA_R_32F, lda, B, CUDA_R_32F, ldb,
      beta, C, CUDA_R_32F, ldc,
      CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}

template <typename T>
int wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight,
                               int in_dim, int hidden_dim, int out_dim) {
    cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
    cudaStream_t stream;
    cublasGetStream(handle, &stream);
    const float alpha = 1.0;
    const float beta  = 1.0;
    int status = 1;

    status = gemmex_wrapper(
        handle,
        CUBLAS_OP_N,
        CUBLAS_OP_T,
        in_dim,
        out_dim,
        hidden_dim,
        &alpha,
        input,
        in_dim,
        d_output,
        out_dim,
        &beta,
        d_weight,
        in_dim);
    return status;
}

template int wgrad_gemm_accum_fp32_cuda<at::Half>(
    at::Half *input, at::Half *d_output, float *d_weight,
    int in_dim, int hidden_dim, int out_dim);
template int wgrad_gemm_accum_fp32_cuda<at::BFloat16>(
    at::BFloat16 *input, at::BFloat16 *d_output, float *d_weight,
    int in_dim, int hidden_dim, int out_dim);
template int wgrad_gemm_accum_fp32_cuda<float>(
    float *input, float *d_output, float *d_weight,
    int in_dim, int hidden_dim, int out_dim);
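Why CUBLAS_OP_N / CUBLAS_OP_T and these leading dimensions give the weight gradient: cuBLAS is column-major, so each row-major tensor passed with its row length as the leading dimension is seen as its transpose. With m = in_dim, n = out_dim, k = hidden_dim and alpha = beta = 1 the call computes, in column-major terms, d_weight^T += input^T @ d_output, which is d_weight += d_output^T @ input in the row-major view PyTorch uses. A small NumPy check of that layout argument (illustration only, not part of the commit):

import numpy as np

hidden_dim, in_dim, out_dim = 8, 4, 3
inp = np.random.rand(hidden_dim, in_dim).astype(np.float32)     # row-major [hidden, in]
d_out = np.random.rand(hidden_dim, out_dim).astype(np.float32)  # row-major [hidden, out]
d_w = np.random.rand(out_dim, in_dim).astype(np.float32)        # row-major [out, in]

expected = d_w + d_out.T @ inp                  # accumulate into the existing gradient

# Column-major view of the same buffers, as cuBLAS sees them (transpose views).
A = inp.T          # [in_dim, hidden_dim], lda = in_dim
B = d_out.T        # [out_dim, hidden_dim], ldb = out_dim
C = d_w.T          # [in_dim, out_dim],    ldc = in_dim
C_new = A @ B.T + C                             # op(A)=N, op(B)=T, alpha = beta = 1

assert np.allclose(C_new.T, expected)           # row-major result matches d_weight += d_output^T @ input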
megatron/fused_kernels/type_shim.h

@@ -39,6 +39,32 @@

   }

+#define DISPATCH_HALF_BFLOAT_AND_FLOAT(TYPE, NAME, ...)                     \
+  switch(TYPE)                                                              \
+  {                                                                         \
+    case at::ScalarType::Half:                                              \
+    {                                                                       \
+      using scalar_t = at::Half;                                            \
+      __VA_ARGS__;                                                          \
+      break;                                                                \
+    }                                                                       \
+    case at::ScalarType::BFloat16:                                          \
+    {                                                                       \
+      using scalar_t = at::BFloat16;                                        \
+      __VA_ARGS__;                                                          \
+      break;                                                                \
+    }                                                                       \
+    case at::ScalarType::Float:                                             \
+    {                                                                       \
+      using scalar_t = float;                                               \
+      __VA_ARGS__;                                                          \
+      break;                                                                \
+    }                                                                       \
+    default:                                                                \
+      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");       \
+  }
+
 #define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
   switch(TYPEIN)                                                               \
megatron/model/distributed.py

@@ -164,13 +164,13 @@ class DistributedDataParallel(DistributedDataParallelBase):

                 grad_acc.register_hook(self._make_param_hook(param))
                 self.grad_accs.append(grad_acc)

     def _make_param_hook(self, param):
         """Create the all-reduce hook for backprop."""
         # Hook used for back-prop.
         def param_hook(*unused):
             # Add the gradient to the buffer.
-            if param.grad.data is not None:
+            if param.grad is not None:
+                # The gradient function of linear layers is fused with GEMMs
                 param.main_grad.add_(param.grad.data)
                 # Now we can deallocate grad memory.
                 param.grad = None
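The hook change above interacts with the new fused path in layers.py: when gradient accumulation fusion is enabled, the weight gradient of the linear layers is written straight into param.main_grad by the fused kernel and the autograd Function returns grad_weight = None, so no .grad tensor is materialized for those parameters and the hook must guard on param.grad rather than param.grad.data. A simplified, runnable sketch of the resulting hook logic (illustration, not the exact repository code):

import torch

def make_param_hook(param):
    # Mirrors _make_param_hook: accumulate into the fp32 buffer only when a
    # .grad tensor was actually produced by autograd.
    def param_hook(*unused):
        if param.grad is not None:
            param.main_grad.add_(param.grad.data)
            param.grad = None   # free the temporary gradient
        # Fused linear-layer weights: backward already called
        # wgrad_gemm_accum_fp32 on param.main_grad, so there is nothing to add.
    return param_hook

# Toy check on CPU for a parameter whose grad was produced normally.
p = torch.nn.Parameter(torch.ones(2))
p.main_grad = torch.zeros(2)
p.grad = torch.full((2,), 0.5)
make_param_hook(p)()
print(p.main_grad, p.grad)   # tensor([0.5000, 0.5000]) None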
megatron/model/language_model.py

@@ -29,13 +29,19 @@ from megatron.model.utils import init_method_normal, scaled_init_method_normal

 def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
                        bias=None):
     """LM logits using word embedding weights."""
+    args = get_args()
     # Parallel logits.
-    input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
-    # Matrix multiply.
-    if bias is None:
-        logits_parallel = F.linear(input_parallel, word_embeddings_weight)
+    if args.async_tensor_model_parallel_allreduce:
+        input_parallel = input_
+        async_grad_allreduce = mpu.get_tensor_model_parallel_world_size() > 1
     else:
-        logits_parallel = F.linear(input_parallel, word_embeddings_weight, bias)
+        input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
+        async_grad_allreduce = False
+    # Matrix multiply.
+    logits_parallel = mpu.LinearWithGradAccumulationAndAsyncAllreduce.apply(
+        input_parallel, word_embeddings_weight, bias,
+        args.gradient_accumulation_fusion, async_grad_allreduce)
     # Gather if needed.
     if parallel_output:
         return logits_parallel
megatron/mpu/__init__.py

@@ -49,6 +49,7 @@ from .initialize import get_virtual_pipeline_model_parallel_rank, set_virtual_pi

 from .initialize import initialize_model_parallel
 from .initialize import model_parallel_is_initialized

+from .layers import LinearWithGradAccumulationAndAsyncAllreduce
 from .layers import ColumnParallelLinear
 from .layers import RowParallelLinear
 from .layers import VocabParallelEmbedding
megatron/mpu/layers.py

@@ -199,15 +199,18 @@ class VocabParallelEmbedding(torch.nn.Module):

         return output


-class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):
+class LinearWithGradAccumulationAndAsyncAllreduce(torch.autograd.Function):
     """
-    Column-parallel linear layer execution with asynchronous all-reduce
-    execution in backprop.
+    Linear layer execution with asynchronous all-reduce and gradient accumulation
+    fusion in backprop.
     """
     @staticmethod
-    def forward(ctx, input, weight, bias):
+    def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
+                async_grad_allreduce):
         ctx.save_for_backward(input, weight)
         ctx.use_bias = bias is not None
+        ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
+        ctx.async_grad_allreduce = async_grad_allreduce
         output = torch.matmul(input, weight.t())
         if bias is not None:
             output = output + bias

@@ -215,19 +218,32 @@ class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):

     @staticmethod
     def backward(ctx, grad_output):
+        import fused_dense_cuda
         input, weight = ctx.saved_tensors
         use_bias = ctx.use_bias
         grad_input = grad_output.matmul(weight)
-        # Asyncronous all-reduce
-        handle = torch.distributed.all_reduce(
-                grad_input, group=get_tensor_model_parallel_group(), async_op=True)
-        # Delay the start of weight gradient computation shortly (3us) to have
-        # all-reduce scheduled first and have GPU resources allocated
-        _ = torch.empty(1, device=grad_output.device) + 1
-        grad_weight = grad_output.t().matmul(input)
+        # Convert the tensor shapes to 2D for execution compatibility
+        grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1],
+                                       grad_output.shape[2])
+        input = input.view(input.shape[0] * input.shape[1], input.shape[2])
+
+        if ctx.async_grad_allreduce:
+            # Asynchronous all-reduce
+            handle = torch.distributed.all_reduce(
+                    grad_input, group=get_tensor_model_parallel_group(), async_op=True)
+            # Delay the start of weight gradient computation shortly (3us) to have
+            # all-reduce scheduled first and have GPU resources allocated
+            _ = torch.empty(1, device=grad_output.device) + 1
+
+        if ctx.gradient_accumulation_fusion:
+            fused_dense_cuda.wgrad_gemm_accum_fp32(input, grad_output, weight.main_grad)
+            grad_weight = None
+        else:
+            grad_weight = grad_output.t().matmul(input)
         grad_bias = grad_output.sum(dim=0) if use_bias else None
-        handle.wait()
-        return grad_input, grad_weight, grad_bias
+        if ctx.async_grad_allreduce:
+            handle.wait()
+        return grad_input, grad_weight, grad_bias, None, None
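A hedged sketch of driving the renamed autograd Function directly; its real callers are parallel_lm_logits, ColumnParallelLinear and RowParallelLinear. It assumes megatron.mpu is importable, the fused_dense_cuda extension has been built, and a CUDA device is present. The point it illustrates: forward now takes five arguments, so backward returns five values, with None for the two boolean flags, and with fusion enabled the weight gradient lands in weight.main_grad rather than weight.grad.

import torch
from megatron import mpu

inp = torch.randn(2, 4, 64, device='cuda', dtype=torch.half, requires_grad=True)
weight = torch.nn.Parameter(torch.randn(128, 64, device='cuda', dtype=torch.half))
weight.main_grad = torch.zeros(128, 64, device='cuda', dtype=torch.float)  # fp32 buffer

out = mpu.LinearWithGradAccumulationAndAsyncAllreduce.apply(
    inp, weight, None,   # input, weight, bias
    True,                # gradient_accumulation_fusion: wgrad accumulates into main_grad
    False)               # async_grad_allreduce: skip the tensor-parallel all-reduce
out.sum().backward()

print(inp.grad.shape)                    # grad_input, same shape as inp
print(weight.grad)                       # None: grad_weight was not materialized
print(weight.main_grad.abs().sum() > 0)  # the fused kernel wrote into main_grad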
@@ -240,7 +256,7 @@ class ColumnParallelLinear(torch.nn.Module):

         input_size: first dimension of matrix A.
         output_size: second dimension of matrix A.
         bias: If true, add bias
-        gather_output: If true, call all-gather on output and make Y avaiable
+        gather_output: If true, call all-gather on output and make Y available
                        to all GPUs, otherwise, every GPU will have its output
                        which is Y_i = XA_i
         init_method: method to initialize weights. Note that bias is always set

@@ -305,29 +321,23 @@ class ColumnParallelLinear(torch.nn.Module):

         else:
             self.register_parameter('bias', None)

         self.async_tensor_model_parallel_allreduce = (
-                not args.no_async_tensor_model_parallel_allreduce and
+                args.async_tensor_model_parallel_allreduce and
                 world_size > 1)
+        self.gradient_accumulation_fusion = args.gradient_accumulation_fusion

     def forward(self, input_):
         bias = self.bias if not self.skip_bias_add else None

         if self.async_tensor_model_parallel_allreduce:
-            input_shape = input_.shape
-            input_ = input_.view(input_shape[0] * input_shape[1], input_shape[2])
-            # Maxtrix multiply with asynchronouse all-reduce execution
-            output_parallel = ColumnParallelLinearWithAsyncAllreduce.apply(
-                    input_, self.weight, bias)
-            output_parallel = output_parallel.view(
-                    input_shape[0], input_shape[1], output_parallel.shape[1])
+            input_parallel = input_
         else:
             # Set up backprop all-reduce.
             input_parallel = copy_to_tensor_model_parallel_region(input_)
-            # Matrix multiply.
-            output_parallel = F.linear(input_parallel, self.weight, bias)
+        # Matrix multiply.
+        output_parallel = LinearWithGradAccumulationAndAsyncAllreduce.apply(
+            input_parallel, self.weight, bias, self.gradient_accumulation_fusion,
+            self.async_tensor_model_parallel_allreduce)
         if self.gather_output:
             # All-gather across the partitions.
             output = gather_from_tensor_model_parallel_region(output_parallel)

@@ -415,7 +425,7 @@ class RowParallelLinear(torch.nn.Module):

                 self.bias.zero_()
         else:
             self.register_parameter('bias', None)
+        self.gradient_accumulation_fusion = args.gradient_accumulation_fusion

     def forward(self, input_):

@@ -425,7 +435,9 @@ class RowParallelLinear(torch.nn.Module):

         else:
             input_parallel = scatter_to_tensor_model_parallel_region(input_)
         # Matrix multiply.
-        output_parallel = F.linear(input_parallel, self.weight)
+        output_parallel = LinearWithGradAccumulationAndAsyncAllreduce.apply(
+            input_parallel, self.weight, None,
+            self.gradient_accumulation_fusion, None)
         # All-reduce across all the partitions.
         output_ = reduce_from_tensor_model_parallel_region(output_parallel)
         if not self.skip_bias_add: