OpenDAS / Megatron-LM · Commits

Commit b5726555, authored Feb 18, 2022 by Sangkug Lym
Parent: 83b1e42f

    support fp32 training and fix embedding update

Showing 6 changed files with 73 additions and 22 deletions (+73, -22):
  megatron/arguments.py                                      +1   -1
  megatron/fused_kernels/fused_weight_gradient_dense.cpp     +1   -1
  megatron/fused_kernels/fused_weight_gradient_dense.cu      +39  -0
  megatron/fused_kernels/type_shim.h                         +26  -0
  megatron/model/distributed.py                              +6   -13
  megatron/mpu/layers.py                                     +0   -7

megatron/arguments.py

@@ -541,7 +541,7 @@ def _add_training_args(parser):
                        'size is supported.')
     group.add_argument('--no-gradient-accumulation-fusion',
                        action='store_false',
-                       help='Disable fuisng gradient accumulation to weight '
+                       help='Disable fusing gradient accumulation to weight '
                        'gradient computation of linear layers',
                        dest='gradient_accumulation_fusion')
     return parser
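
Note on the hunk above: the only change is the typo fix in the help string ('fuisng' to 'fusing'). For reference, a minimal stand-alone argparse sketch (not part of the commit) of how action='store_false' together with dest='gradient_accumulation_fusion' behaves: the attribute defaults to True, and passing the --no- switch turns fusion off.

    # Minimal sketch (not from the commit): behavior of action='store_false' with dest.
    import argparse

    parser = argparse.ArgumentParser()
    group = parser.add_argument_group(title='training')
    group.add_argument('--no-gradient-accumulation-fusion',
                       action='store_false',
                       help='Disable fusing gradient accumulation to weight '
                       'gradient computation of linear layers',
                       dest='gradient_accumulation_fusion')

    # Fusion is enabled unless the flag is passed.
    assert parser.parse_args([]).gradient_accumulation_fusion is True
    assert parser.parse_args(
        ['--no-gradient-accumulation-fusion']).gradient_accumulation_fusion is False
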
megatron/fused_kernels/fused_weight_gradient_dense.cpp

@@ -31,7 +31,7 @@ void wgrad_gemm_accum_fp32(const at::Tensor input, const at::Tensor d_output, at
     int in_dim = input_2d.size(1);
     int out_dim = d_weight.size(0);
-    DISPATCH_HALF_AND_BFLOAT(input_2d.scalar_type(), "wgrad_gemm_accum_fp32",
+    DISPATCH_HALF_BFLOAT_AND_FLOAT(input_2d.scalar_type(), "wgrad_gemm_accum_fp32",
         int result = wgrad_gemm_accum_fp32_cuda<scalar_t>(
             input_2d.data_ptr<scalar_t>(),
             d_output_2d.data_ptr<scalar_t>(),
megatron/fused_kernels/fused_weight_gradient_dense.cu

@@ -87,6 +87,44 @@ cublasStatus_t gemmex_wrapper(
       CUBLAS_GEMM_DEFAULT_TENSOR_OP);
 }
 
+// FP32 Tensor core wrapper around cublas GEMMEx
+cublasStatus_t gemmex_wrapper(
+    cublasHandle_t handle,
+    cublasOperation_t transa,
+    cublasOperation_t transb,
+    int m,
+    int n,
+    int k,
+    const float* alpha,
+    float* A,
+    int lda,
+    float* B,
+    int ldb,
+    const float* beta,
+    float* C,
+    int ldc) {
+  return cublasGemmEx(
+      handle,
+      transa,
+      transb,
+      m,
+      n,
+      k,
+      alpha,
+      A,
+      CUDA_R_32F,
+      lda,
+      B,
+      CUDA_R_32F,
+      ldb,
+      beta,
+      C,
+      CUDA_R_32F,
+      ldc,
+      CUDA_R_32F,
+      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
+}
+
 template <typename T>
 int wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim) {
     cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();

@@ -116,3 +154,4 @@ int wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight, int in_di
 template int wgrad_gemm_accum_fp32_cuda<at::Half>(at::Half *input, at::Half *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
 template int wgrad_gemm_accum_fp32_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
+template int wgrad_gemm_accum_fp32_cuda<float>(float *input, float *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
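
The new gemmex_wrapper overload routes FP32 operands through cublasGemmEx with CUDA_R_32F types, and the added template instantiation exposes wgrad_gemm_accum_fp32_cuda<float>, so the fused weight-gradient path also covers models trained in FP32. As a rough pure-PyTorch reference for what the kernel accumulates (a sketch, not the CUDA code), with input_2d of shape [hidden, in_dim], d_output_2d of shape [hidden, out_dim], and an FP32 d_weight buffer of shape [out_dim, in_dim]:

    # Pure-PyTorch reference sketch of the fused accumulation (not the CUDA kernel).
    import torch

    def wgrad_accum_fp32_reference(input_2d, d_output_2d, d_weight_fp32):
        # d_weight_fp32 += d_output_2d^T @ input_2d, accumulated in FP32
        # regardless of whether the activations are fp16, bf16, or fp32.
        d_weight_fp32.add_(d_output_2d.float().t().matmul(input_2d.float()))

    # Example: accumulate two micro-batches into one FP32 buffer.
    inp = torch.randn(8, 4)          # [hidden, in_dim]
    dout = torch.randn(8, 6)         # [hidden, out_dim]
    main_grad = torch.zeros(6, 4)    # [out_dim, in_dim], fp32
    wgrad_accum_fp32_reference(inp, dout, main_grad)
    wgrad_accum_fp32_reference(inp, dout, main_grad)
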
megatron/fused_kernels/type_shim.h

@@ -39,6 +39,32 @@
   }
 
+#define DISPATCH_HALF_BFLOAT_AND_FLOAT(TYPE, NAME, ...) \
+  switch(TYPE) \
+  { \
+    case at::ScalarType::Half: \
+    { \
+      using scalar_t = at::Half; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    case at::ScalarType::BFloat16: \
+    { \
+      using scalar_t = at::BFloat16; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    case at::ScalarType::Float: \
+    { \
+      using scalar_t = float; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    default: \
+      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
+  }
+
 #define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
   switch(TYPEIN) \
megatron/model/distributed.py

@@ -168,21 +168,14 @@ class DistributedDataParallel(DistributedDataParallelBase):
         """Create the all-reduce hook for backprop."""
         # Hook used for back-prop.
         def param_hook(*unused):
-            if not self.skip_gradient_func(param):
-                # Add the gradient to the buffer.
-                if param.grad.data is not None:
-                    # The gradient function of linear layers is fused with GEMMs
-                    param.main_grad.add_(param.grad.data)
-                    # Now we can deallocate grad memory.
-                    param.grad = None
+            # Add the gradient to the buffer.
+            if param.grad is not None:
+                # The gradient function of linear layers is fused with GEMMs
+                param.main_grad.add_(param.grad.data)
+                # Now we can deallocate grad memory.
+                param.grad = None
         return param_hook
 
-    def skip_gradient_func(self, param):
-        # Skip gradient function of linear layers
-        # Gradient accumulation is fused to weight gradient computation operators
-        if getattr(param, 'fuse_gradient_accumulation', False):
-            return True
-        return False
 
     def zero_grad_buffer(self):
         """Set the grad buffer data to zero. Needs to be called at the
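
With skip_gradient_func removed, the hook keys off param.grad alone: layers whose weight gradients are written directly into param.main_grad by the fused kernels leave param.grad unset, while every other parameter (the embedding weight included, now that its fuse_gradient_accumulation attribute is gone) still has its gradient folded into the FP32 buffer here, which is presumably the "fix embedding update" part of the commit message. Below is a stand-alone illustration of the hook pattern (my own sketch, not the Megatron code); the gradient-accumulation-node registration mirrors how such hooks are typically attached.

    # Stand-alone sketch (not Megatron code) of a main_grad accumulation hook.
    import torch

    def attach_main_grad_hook(param):
        # Persistent FP32 accumulation buffer, analogous to param.main_grad.
        param.main_grad = torch.zeros_like(param, dtype=torch.float)

        def param_hook(*unused):
            # Fused layers write straight into main_grad and leave .grad unset.
            if param.grad is not None:
                param.main_grad.add_(param.grad.data)
                # Now we can deallocate grad memory.
                param.grad = None

        # Hook the gradient-accumulation node so param_hook runs each backward.
        param_tmp = param.expand_as(param)
        grad_acc = param_tmp.grad_fn.next_functions[0][0]
        grad_acc.register_hook(param_hook)
        param._grad_acc = grad_acc   # keep a reference so the hook stays alive

    p = torch.nn.Parameter(torch.randn(3))
    attach_main_grad_hook(p)
    (2 * p).sum().backward()
    print(p.main_grad)   # gradient (all 2s) accumulated in FP32; p.grad freed
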
megatron/mpu/layers.py

@@ -175,8 +175,6 @@ class VocabParallelEmbedding(torch.nn.Module):
             device=torch.cuda.current_device(), dtype=args.params_dtype))
         _initialize_affine_weight_gpu(self.weight, init_method,
                                       partition_dim=0, stride=1)
-        setattr(self.weight, 'fuse_gradient_accumulation',
-                args.gradient_accumulation_fusion)
 
     def forward(self, input_):
         if self.tensor_model_parallel_size > 1:

@@ -241,7 +239,6 @@ class LinearWithGradAccumulationAndAsyncAllreduce(torch.autograd.Function):
             fused_dense_cuda.wgrad_gemm_accum_fp32(input, grad_output, weight.main_grad)
             grad_weight = None
         else:
-            # Matrix multiply with asynchronous all-reduce execution
             grad_weight = grad_output.t().matmul(input)
         grad_bias = grad_output.sum(dim=0) if use_bias else None
         if ctx.async_grad_allreduce:

@@ -327,8 +324,6 @@ class ColumnParallelLinear(torch.nn.Module):
                 args.async_tensor_model_parallel_allreduce and
                 world_size > 1)
         self.gradient_accumulation_fusion = args.gradient_accumulation_fusion
-        setattr(self.weight, 'fuse_gradient_accumulation',
-                self.gradient_accumulation_fusion)
 
     def forward(self, input_):

@@ -431,8 +426,6 @@ class RowParallelLinear(torch.nn.Module):
         else:
             self.register_parameter('bias', None)
         self.gradient_accumulation_fusion = args.gradient_accumulation_fusion
-        setattr(self.weight, 'fuse_gradient_accumulation',
-                self.gradient_accumulation_fusion)
 
     def forward(self, input_):
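
For context, the backward of LinearWithGradAccumulationAndAsyncAllreduce keeps the two weight-gradient paths visible in the hunk above: the fused path accumulates into weight.main_grad via fused_dense_cuda.wgrad_gemm_accum_fp32 and returns no grad_weight, while the unfused path materializes grad_weight = grad_output.t().matmul(input). A hedged sketch of that branch (illustrative only; fused_wgrad_accum below is a stand-in for the CUDA extension):

    # Illustrative sketch (not the Megatron implementation) of the two
    # weight-gradient paths; fused_wgrad_accum stands in for the CUDA extension.
    import torch

    def fused_wgrad_accum(input_2d, grad_output_2d, main_grad):
        # Stand-in for fused_dense_cuda.wgrad_gemm_accum_fp32: FP32 accumulation.
        main_grad.add_(grad_output_2d.float().t().matmul(input_2d.float()))

    def linear_wgrad(input_2d, grad_output_2d, weight, gradient_accumulation_fusion):
        if gradient_accumulation_fusion:
            # Fused path: gradient goes straight into the FP32 main_grad buffer.
            fused_wgrad_accum(input_2d, grad_output_2d, weight.main_grad)
            return None                     # grad_weight is not materialized
        # Unfused path: return an ordinary weight-gradient tensor.
        return grad_output_2d.t().matmul(input_2d)

    # Both paths agree (in FP32) on the accumulated weight gradient.
    w = torch.nn.Parameter(torch.randn(6, 4))
    w.main_grad = torch.zeros_like(w)
    inp, dout = torch.randn(8, 4), torch.randn(8, 6)
    assert linear_wgrad(inp, dout, w, gradient_accumulation_fusion=True) is None
    unfused = linear_wgrad(inp, dout, w, gradient_accumulation_fusion=False)
    torch.testing.assert_close(w.main_grad, unfused)
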