Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
51a2e6b0
Commit
51a2e6b0
authored
Sep 11, 2020
by
Vijay Korthikanti
Committed by
Jared Casper
Sep 11, 2020
Browse files
Various speed optimizations.
parent
12518332
Changes
13
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
1103 additions
and
163 deletions
+1103
-163
megatron/arguments.py
megatron/arguments.py
+13
-1
megatron/fused_kernels/__init__.py
megatron/fused_kernels/__init__.py
+53
-0
megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp
...tron/fused_kernels/scaled_upper_triang_masked_softmax.cpp
+69
-0
megatron/fused_kernels/scaled_upper_triang_masked_softmax.h
megatron/fused_kernels/scaled_upper_triang_masked_softmax.h
+439
-0
megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu
.../fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu
+89
-0
megatron/model/fused_bias_gelu.py
megatron/model/fused_bias_gelu.py
+60
-0
megatron/model/fused_softmax.py
megatron/model/fused_softmax.py
+94
-0
megatron/model/language_model.py
megatron/model/language_model.py
+2
-12
megatron/model/transformer.py
megatron/model/transformer.py
+244
-132
megatron/mpu/layers.py
megatron/mpu/layers.py
+25
-9
megatron/training.py
megatron/training.py
+12
-3
megatron/utils.py
megatron/utils.py
+3
-3
pretrain_gpt2.py
pretrain_gpt2.py
+0
-3
No files found.
megatron/arguments.py
View file @
51a2e6b0
...
...
@@ -19,7 +19,7 @@ import argparse
import
os
import
torch
from
megatron
import
fused_kernels
def
parse_args
(
extra_args_provider
=
None
,
defaults
=
{},
ignore_unknown_args
=
False
):
...
...
@@ -118,6 +118,10 @@ def parse_args(extra_args_provider=None, defaults={},
'for distribute-checkpointed-activations to work you '
\
'need to enable checkpoint-activations'
# load scaled_upper_triang_masked_softmax_fusion kernel
if
args
.
scaled_upper_triang_masked_softmax_fusion
:
fused_kernels
.
load_scaled_upper_triang_masked_softmax_fusion_kernel
()
_print_args
(
args
)
return
args
...
...
@@ -221,6 +225,14 @@ def _add_training_args(parser):
'by this value.'
)
group
.
add_argument
(
'--tensorboard-dir'
,
type
=
str
,
default
=
None
,
help
=
'Write TensorBoard logs to this directory.'
)
group
.
add_argument
(
'--scaled-upper-triang-masked-softmax-fusion'
,
action
=
'store_true'
,
help
=
'Enable fusion of query_key_value_scaling '
'time (upper diagonal) masking, softmax.'
)
group
.
add_argument
(
'--bias-gelu-fusion'
,
action
=
'store_true'
,
help
=
'Enable bias and gelu fusion.'
)
group
.
add_argument
(
'--bias-dropout-fusion'
,
action
=
'store_true'
,
help
=
'Enable bias and dropout fusion.'
)
return
parser
...
...
megatron/fused_kernels/__init__.py
0 → 100644
View file @
51a2e6b0
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
pathlib
import
subprocess
from
torch.utils
import
cpp_extension
def
load_scaled_upper_triang_masked_softmax_fusion_kernel
():
def
get_cuda_bare_metal_version
(
cuda_dir
):
raw_output
=
subprocess
.
check_output
([
cuda_dir
+
"/bin/nvcc"
,
"-V"
],
universal_newlines
=
True
)
output
=
raw_output
.
split
()
release_idx
=
output
.
index
(
"release"
)
+
1
release
=
output
[
release_idx
].
split
(
"."
)
bare_metal_major
=
release
[
0
]
bare_metal_minor
=
release
[
1
][
0
]
return
raw_output
,
bare_metal_major
,
bare_metal_minor
# Check, if CUDA11 is installed for compute capability 8.0
cc_flag
=
[]
_
,
bare_metal_major
,
_
=
get_cuda_bare_metal_version
(
cpp_extension
.
CUDA_HOME
)
if
int
(
bare_metal_major
)
>=
11
:
cc_flag
.
append
(
'-gencode'
)
cc_flag
.
append
(
'arch=compute_80,code=sm_80'
)
srcpath
=
pathlib
.
Path
(
__file__
).
parent
.
absolute
()
scaled_upper_triang_masked_softmax_cuda
=
cpp_extension
.
load
(
name
=
'scaled_upper_triang_masked_softmax_cuda'
,
sources
=
[
srcpath
/
'scaled_upper_triang_masked_softmax.cpp'
,
srcpath
/
'scaled_upper_triang_masked_softmax_cuda.cu'
],
extra_cflags
=
[
'-O3'
,],
extra_cuda_cflags
=
[
'-O3'
,
'-gencode'
,
'arch=compute_70,code=sm_70'
,
'-U__CUDA_NO_HALF_OPERATORS__'
,
'-U__CUDA_NO_HALF_CONVERSIONS__'
,
'--expt-relaxed-constexpr'
,
'--expt-extended-lambda'
,
'--use_fast_math'
]
+
cc_flag
,
verbose
=
True
)
megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp
0 → 100644
View file @
51a2e6b0
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace
multihead_attn
{
namespace
fused_softmax
{
namespace
scaled_upper_triang_masked_softmax
{
torch
::
Tensor
fwd_cuda
(
torch
::
Tensor
const
&
input
,
float
scale_factor
);
torch
::
Tensor
bwd_cuda
(
torch
::
Tensor
const
&
output_grads
,
torch
::
Tensor
const
&
softmax_results
,
float
scale_factor
);
torch
::
Tensor
fwd
(
torch
::
Tensor
const
&
input
,
float
scale_factor
)
{
AT_ASSERTM
(
input
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
input
.
scalar_type
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
return
fwd_cuda
(
input
,
scale_factor
);
}
torch
::
Tensor
bwd
(
torch
::
Tensor
const
&
output_grads
,
torch
::
Tensor
const
&
softmax_results
,
float
scale_factor
)
{
AT_ASSERTM
(
output_grads
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
softmax_results
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
output_grads
.
scalar_type
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
softmax_results
.
scalar_type
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
return
bwd_cuda
(
output_grads
,
softmax_results
,
scale_factor
);
}
}
// end namespace scaled_upper_triang_masked_softmax
}
// end namespace fused_softmax
}
// end namespace multihead_attn
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"forward"
,
&
multihead_attn
::
fused_softmax
::
scaled_upper_triang_masked_softmax
::
fwd
,
"Self Multihead Attention scaled, time masked softmax -- Forward."
);
m
.
def
(
"backward"
,
&
multihead_attn
::
fused_softmax
::
scaled_upper_triang_masked_softmax
::
bwd
,
"Self Multihead Attention scaled, time masked softmax -- Backward."
);
}
megatron/fused_kernels/scaled_upper_triang_masked_softmax.h
0 → 100644
View file @
51a2e6b0
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <assert.h>
#include <cuda_fp16.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <cuda_fp16.h>
#include <c10/macros/Macros.h>
namespace
{
int
log2_ceil
(
int
value
)
{
int
log2_value
=
0
;
while
((
1
<<
log2_value
)
<
value
)
++
log2_value
;
return
log2_value
;
}
template
<
typename
T
>
struct
Add
{
__device__
__forceinline__
T
operator
()(
T
a
,
T
b
)
const
{
return
a
+
b
;
}
};
template
<
typename
T
>
struct
Max
{
__device__
__forceinline__
T
operator
()(
T
a
,
T
b
)
const
{
return
a
<
b
?
b
:
a
;
}
};
template
<
typename
T
>
__device__
__forceinline__
T
WARP_SHFL_XOR_NATIVE
(
T
value
,
int
laneMask
,
int
width
=
warpSize
,
unsigned
int
mask
=
0xffffffff
)
{
#if CUDA_VERSION >= 9000
return
__shfl_xor_sync
(
mask
,
value
,
laneMask
,
width
);
#else
return
__shfl_xor
(
value
,
laneMask
,
width
);
#endif
}
template
<
typename
acc_t
,
int
WARP_BATCH
,
int
WARP_SIZE
,
template
<
typename
>
class
ReduceOp
>
__device__
__forceinline__
void
warp_reduce
(
acc_t
*
sum
)
{
ReduceOp
<
acc_t
>
r
;
#pragma unroll
for
(
int
offset
=
WARP_SIZE
/
2
;
offset
>
0
;
offset
/=
2
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
acc_t
b
=
WARP_SHFL_XOR_NATIVE
(
sum
[
i
],
offset
,
WARP_SIZE
);
sum
[
i
]
=
r
(
sum
[
i
],
b
);
}
}
}
/*
* Extended softmax (from native aten pytorch) with following additional features
* 1) input scaling
* 2) Implicit time (diagonal masking)
*/
template
<
typename
input_t
,
typename
output_t
,
typename
acc_t
,
int
log2_elements
>
__global__
void
scaled_upper_triang_masked_softmax_warp_forward
(
output_t
*
dst
,
const
input_t
*
src
,
const
acc_t
scale
,
int
batch_size
,
int
stride
,
int
element_count
)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_forward_kernel.
constexpr
int
next_power_of_two
=
1
<<
log2_elements
;
constexpr
int
WARP_SIZE
=
(
next_power_of_two
<
C10_WARP_SIZE
)
?
next_power_of_two
:
C10_WARP_SIZE
;
constexpr
int
WARP_ITERATIONS
=
next_power_of_two
/
WARP_SIZE
;
constexpr
int
WARP_BATCH
=
(
next_power_of_two
<=
128
)
?
2
:
1
;
int
first_batch
=
(
blockDim
.
y
*
blockIdx
.
y
+
threadIdx
.
y
)
*
gridDim
.
x
*
WARP_BATCH
+
blockIdx
.
x
;
int
local_seq
=
blockIdx
.
x
+
1
;
int
warp_iteration_limit
=
(
local_seq
+
WARP_SIZE
-
1
)
/
WARP_SIZE
;
// batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
int
local_batches
=
batch_size
-
first_batch
;
if
(
local_batches
>
WARP_BATCH
)
local_batches
=
WARP_BATCH
;
// there might be multiple batches per warp. compute the index within the batch
int
local_idx
=
threadIdx
.
x
;
src
+=
first_batch
*
stride
+
local_idx
;
dst
+=
first_batch
*
stride
+
local_idx
;
// load data from global memory
acc_t
elements
[
WARP_BATCH
][
WARP_ITERATIONS
];
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
int
batch_element_count
=
(
i
>=
local_batches
)
?
0
:
local_seq
;
#pragma unroll
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
++
it
)
{
int
element_index
=
local_idx
+
it
*
WARP_SIZE
;
if
(
element_index
<
batch_element_count
)
{
elements
[
i
][
it
]
=
(
acc_t
)
src
[
i
*
element_count
*
stride
+
it
*
WARP_SIZE
]
*
scale
;
}
else
{
elements
[
i
][
it
]
=
-
std
::
numeric_limits
<
acc_t
>::
infinity
();
}
}
}
// compute max_value
acc_t
max_value
[
WARP_BATCH
];
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
max_value
[
i
]
=
elements
[
i
][
0
];
#pragma unroll
for
(
int
it
=
1
;
it
<
WARP_ITERATIONS
;
++
it
)
{
max_value
[
i
]
=
(
max_value
[
i
]
>
elements
[
i
][
it
])
?
max_value
[
i
]
:
elements
[
i
][
it
];
}
}
warp_reduce
<
acc_t
,
WARP_BATCH
,
WARP_SIZE
,
Max
>
(
max_value
);
acc_t
sum
[
WARP_BATCH
]
{
0.0
f
};
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
#pragma unroll
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
++
it
)
{
if
(
it
<
warp_iteration_limit
)
{
elements
[
i
][
it
]
=
std
::
exp
((
elements
[
i
][
it
]
-
max_value
[
i
]));
sum
[
i
]
+=
elements
[
i
][
it
];
}
}
}
warp_reduce
<
acc_t
,
WARP_BATCH
,
WARP_SIZE
,
Add
>
(
sum
);
// store result
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
if
(
i
>=
local_batches
)
break
;
#pragma unroll
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
++
it
)
{
int
element_index
=
local_idx
+
it
*
WARP_SIZE
;
if
(
element_index
<
local_seq
)
{
dst
[
i
*
element_count
*
stride
+
it
*
WARP_SIZE
]
=
(
output_t
)(
elements
[
i
][
it
]
/
sum
[
i
]);
}
else
if
(
element_index
<
element_count
)
{
dst
[
i
*
element_count
*
stride
+
it
*
WARP_SIZE
]
=
0
;
}
else
{
break
;
}
}
}
}
template
<
typename
input_t
,
typename
output_t
,
typename
acc_t
,
int
log2_elements
>
__global__
void
scaled_upper_triang_masked_softmax_warp_backward
(
output_t
*
gradInput
,
input_t
*
grad
,
const
input_t
*
output
,
acc_t
scale
,
int
batch_size
,
int
stride
,
int
element_count
)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_backward_kernel.
constexpr
int
next_power_of_two
=
1
<<
log2_elements
;
constexpr
int
WARP_SIZE
=
(
next_power_of_two
<
C10_WARP_SIZE
)
?
next_power_of_two
:
C10_WARP_SIZE
;
constexpr
int
WARP_ITERATIONS
=
next_power_of_two
/
WARP_SIZE
;
constexpr
int
WARP_BATCH
=
(
next_power_of_two
<=
128
)
?
2
:
1
;
int
first_batch
=
(
blockDim
.
y
*
blockIdx
.
y
+
threadIdx
.
y
)
*
gridDim
.
x
*
WARP_BATCH
+
blockIdx
.
x
;
int
local_seq
=
blockIdx
.
x
+
1
;
// batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
int
local_batches
=
batch_size
-
first_batch
;
if
(
local_batches
>
WARP_BATCH
)
local_batches
=
WARP_BATCH
;
// there might be multiple batches per warp. compute the index within the batch
int
local_idx
=
threadIdx
.
x
;
// the first element to process by the current thread
int
thread_offset
=
first_batch
*
stride
+
local_idx
;
grad
+=
thread_offset
;
output
+=
thread_offset
;
gradInput
+=
thread_offset
;
// load data from global memory
acc_t
grad_reg
[
WARP_BATCH
][
WARP_ITERATIONS
]
{
0.0
f
};
acc_t
output_reg
[
WARP_BATCH
][
WARP_ITERATIONS
];
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
int
batch_element_count
=
(
i
>=
local_batches
)
?
0
:
local_seq
;
#pragma unroll
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
++
it
)
{
int
element_index
=
local_idx
+
it
*
WARP_SIZE
;
if
(
element_index
<
batch_element_count
)
{
output_reg
[
i
][
it
]
=
output
[
i
*
element_count
*
stride
+
it
*
WARP_SIZE
];
}
else
{
output_reg
[
i
][
it
]
=
acc_t
(
0
);
}
}
#pragma unroll
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
++
it
)
{
int
element_index
=
local_idx
+
it
*
WARP_SIZE
;
if
(
element_index
<
batch_element_count
)
{
grad_reg
[
i
][
it
]
=
(
acc_t
)
grad
[
i
*
element_count
*
stride
+
it
*
WARP_SIZE
]
*
output_reg
[
i
][
it
];
}
else
{
grad_reg
[
i
][
it
]
=
acc_t
(
0
);
}
}
}
acc_t
sum
[
WARP_BATCH
];
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
sum
[
i
]
=
grad_reg
[
i
][
0
];
#pragma unroll
for
(
int
it
=
1
;
it
<
WARP_ITERATIONS
;
++
it
)
{
sum
[
i
]
+=
grad_reg
[
i
][
it
];
}
}
warp_reduce
<
acc_t
,
WARP_BATCH
,
WARP_SIZE
,
Add
>
(
sum
);
// store result
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
if
(
i
>=
local_batches
)
break
;
#pragma unroll
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
++
it
)
{
int
element_index
=
local_idx
+
it
*
WARP_SIZE
;
if
(
element_index
<
element_count
)
{
// compute gradients
gradInput
[
i
*
element_count
*
stride
+
it
*
WARP_SIZE
]
=
(
output_t
)(
scale
*
(
grad_reg
[
i
][
it
]
-
output_reg
[
i
][
it
]
*
sum
[
i
]));
}
}
}
}
}
// end of anonymous namespace
template
<
typename
input_t
,
typename
output_t
,
typename
acc_t
>
void
dispatch_scaled_upper_triang_masked_softmax_forward
(
output_t
*
dst
,
const
input_t
*
src
,
const
input_t
scale
,
int
softmax_elements
,
int
softmax_elements_stride
,
int
attn_batches
)
{
TORCH_INTERNAL_ASSERT
(
softmax_elements
>=
0
&&
softmax_elements
<=
2048
);
if
(
softmax_elements
==
0
)
{
return
;
}
else
{
int
log2_elements
=
log2_ceil
(
softmax_elements
);
const
int
next_power_of_two
=
1
<<
log2_elements
;
int
seq_len
=
softmax_elements
;
int
batch_count
=
attn_batches
*
seq_len
;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
int
warp_size
=
(
next_power_of_two
<
C10_WARP_SIZE
)
?
next_power_of_two
:
C10_WARP_SIZE
;
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
int
batches_per_warp
=
(
next_power_of_two
<=
128
)
?
2
:
1
;
// use 128 threads per block to maximimize gpu utilization
constexpr
int
threads_per_block
=
128
;
int
warps_per_block
=
(
threads_per_block
/
warp_size
);
int
batches_per_block
=
warps_per_block
*
batches_per_warp
;
TORCH_INTERNAL_ASSERT
(
attn_batches
%
batches_per_block
==
0
);
int
blocks_per_seq
=
attn_batches
/
batches_per_block
;
dim3
blocks
(
seq_len
,
blocks_per_seq
,
1
);
dim3
threads
(
warp_size
,
warps_per_block
,
1
);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch
(
log2_elements
)
{
case
0
:
// 1
scaled_upper_triang_masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
0
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
dst
,
src
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
);
break
;
case
1
:
// 2
scaled_upper_triang_masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
1
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
dst
,
src
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
);
break
;
case
2
:
// 4
scaled_upper_triang_masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
2
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
dst
,
src
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
);
break
;
case
3
:
// 8
scaled_upper_triang_masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
3
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
dst
,
src
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
);
break
;
case
4
:
// 16
scaled_upper_triang_masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
4
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
dst
,
src
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
);
break
;
case
5
:
// 32
scaled_upper_triang_masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
5
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
dst
,
src
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
);
break
;
case
6
:
// 64
scaled_upper_triang_masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
6
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
dst
,
src
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
);
break
;
case
7
:
// 128
scaled_upper_triang_masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
7
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
dst
,
src
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
);
break
;
case
8
:
// 256
scaled_upper_triang_masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
8
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
dst
,
src
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
);
break
;
case
9
:
// 512
scaled_upper_triang_masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
9
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
dst
,
src
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
);
break
;
case
10
:
// 1024
scaled_upper_triang_masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
10
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
dst
,
src
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
);
break
;
case
11
:
// 2048
scaled_upper_triang_masked_softmax_warp_forward
<
input_t
,
output_t
,
acc_t
,
11
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
dst
,
src
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
);
break
;
default:
break
;
}
}
}
template
<
typename
input_t
,
typename
output_t
,
typename
acc_t
>
void
dispatch_scaled_upper_triang_masked_softmax_backward
(
output_t
*
grad_input
,
input_t
*
grad
,
const
input_t
*
output
,
const
acc_t
scale
,
int
softmax_elements
,
int
softmax_elements_stride
,
int
attn_batches
)
{
TORCH_INTERNAL_ASSERT
(
softmax_elements
>=
0
&&
softmax_elements
<=
2048
);
if
(
softmax_elements
==
0
)
{
return
;
}
else
{
int
log2_elements
=
log2_ceil
(
softmax_elements
);
const
int
next_power_of_two
=
1
<<
log2_elements
;
int
seq_len
=
softmax_elements
;
int
batch_count
=
attn_batches
*
seq_len
;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
int
warp_size
=
(
next_power_of_two
<
C10_WARP_SIZE
)
?
next_power_of_two
:
C10_WARP_SIZE
;
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
int
batches_per_warp
=
(
next_power_of_two
<=
128
)
?
2
:
1
;
// use 128 threads per block to maximimize gpu utilization
constexpr
int
threads_per_block
=
128
;
int
warps_per_block
=
(
threads_per_block
/
warp_size
);
int
batches_per_block
=
warps_per_block
*
batches_per_warp
;
TORCH_INTERNAL_ASSERT
(
attn_batches
%
batches_per_block
==
0
);
int
blocks_per_seq
=
attn_batches
/
batches_per_block
;
dim3
blocks
(
seq_len
,
blocks_per_seq
,
1
);
dim3
threads
(
warp_size
,
warps_per_block
,
1
);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch
(
log2_elements
)
{
case
0
:
// 1
scaled_upper_triang_masked_softmax_warp_backward
<
input_t
,
output_t
,
acc_t
,
0
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
grad_input
,
grad
,
output
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
);
break
;
case
1
:
// 2
scaled_upper_triang_masked_softmax_warp_backward
<
input_t
,
output_t
,
acc_t
,
1
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
grad_input
,
grad
,
output
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
);
break
;
case
2
:
// 4
scaled_upper_triang_masked_softmax_warp_backward
<
input_t
,
output_t
,
acc_t
,
2
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
grad_input
,
grad
,
output
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
);
break
;
case
3
:
// 8
scaled_upper_triang_masked_softmax_warp_backward
<
input_t
,
output_t
,
acc_t
,
3
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
grad_input
,
grad
,
output
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
);
break
;
case
4
:
// 16
scaled_upper_triang_masked_softmax_warp_backward
<
input_t
,
output_t
,
acc_t
,
4
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
grad_input
,
grad
,
output
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
);
break
;
case
5
:
// 32
scaled_upper_triang_masked_softmax_warp_backward
<
input_t
,
output_t
,
acc_t
,
5
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
grad_input
,
grad
,
output
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
);
break
;
case
6
:
// 64
scaled_upper_triang_masked_softmax_warp_backward
<
input_t
,
output_t
,
acc_t
,
6
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
grad_input
,
grad
,
output
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
);
break
;
case
7
:
// 128
scaled_upper_triang_masked_softmax_warp_backward
<
input_t
,
output_t
,
acc_t
,
7
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
grad_input
,
grad
,
output
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
);
break
;
case
8
:
// 256
scaled_upper_triang_masked_softmax_warp_backward
<
input_t
,
output_t
,
acc_t
,
8
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
grad_input
,
grad
,
output
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
);
break
;
case
9
:
// 512
scaled_upper_triang_masked_softmax_warp_backward
<
input_t
,
output_t
,
acc_t
,
9
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
grad_input
,
grad
,
output
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
);
break
;
case
10
:
// 1024
scaled_upper_triang_masked_softmax_warp_backward
<
input_t
,
output_t
,
acc_t
,
10
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
grad_input
,
grad
,
output
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
);
break
;
case
11
:
// 2048
scaled_upper_triang_masked_softmax_warp_backward
<
input_t
,
output_t
,
acc_t
,
11
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
grad_input
,
grad
,
output
,
scale
,
batch_count
,
softmax_elements_stride
,
softmax_elements
);
break
;
default:
break
;
}
}
}
megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu
0 → 100644
View file @
51a2e6b0
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_upper_triang_masked_softmax.h"
namespace
multihead_attn
{
namespace
fused_softmax
{
namespace
scaled_upper_triang_masked_softmax
{
torch
::
Tensor
fwd_cuda
(
torch
::
Tensor
const
&
input
,
float
scale_factor
)
{
// input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
const
int
attn_batches
=
input
.
size
(
0
);
const
int
seq_len
=
input
.
size
(
1
);
TORCH_INTERNAL_ASSERT
(
seq_len
<=
2048
);
// Output
auto
act_options
=
input
.
options
().
requires_grad
(
false
);
torch
::
Tensor
softmax_results
=
torch
::
empty
({
attn_batches
,
seq_len
,
seq_len
},
act_options
);
// Softmax Intermediate Result Ptr
void
*
input_ptr
=
static_cast
<
void
*>
(
input
.
data_ptr
());
void
*
softmax_results_ptr
=
static_cast
<
void
*>
(
softmax_results
.
data_ptr
());
dispatch_scaled_upper_triang_masked_softmax_forward
<
half
,
half
,
float
>
(
reinterpret_cast
<
half
*>
(
softmax_results_ptr
),
reinterpret_cast
<
const
half
*>
(
input_ptr
),
scale_factor
,
seq_len
,
seq_len
,
attn_batches
);
return
softmax_results
;
}
torch
::
Tensor
bwd_cuda
(
torch
::
Tensor
const
&
output_grads_
,
torch
::
Tensor
const
&
softmax_results_
,
float
scale_factor
)
{
auto
output_grads
=
output_grads_
.
contiguous
();
auto
softmax_results
=
softmax_results_
.
contiguous
();
//output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
const
int
attn_batches
=
output_grads
.
size
(
0
);
const
int
seq_len
=
output_grads
.
size
(
1
);
TORCH_INTERNAL_ASSERT
(
output_grads
.
size
(
1
)
==
output_grads
.
size
(
2
));
void
*
output_grads_ptr
=
static_cast
<
void
*>
(
output_grads
.
data_ptr
());
//Softmax Grad
dispatch_scaled_upper_triang_masked_softmax_backward
<
half
,
half
,
float
>
(
reinterpret_cast
<
half
*>
(
output_grads_ptr
),
reinterpret_cast
<
half
*>
(
output_grads_ptr
),
reinterpret_cast
<
half
const
*>
(
softmax_results
.
data_ptr
()),
scale_factor
,
seq_len
,
seq_len
,
attn_batches
);
//backward pass is completely in-place
return
output_grads
;
}
}
}
}
megatron/model/fused_bias_gelu.py
0 → 100644
View file @
51a2e6b0
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
torch
torch
.
_C
.
_jit_set_profiling_mode
(
False
)
torch
.
_C
.
_jit_set_profiling_executor
(
False
)
torch
.
_C
.
_jit_override_can_fuse_on_cpu
(
True
)
torch
.
_C
.
_jit_override_can_fuse_on_gpu
(
True
)
###### BIAS GELU FUSION/ NO AUTOGRAD ################
# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2) -> 0.70710678
# sqrt(2/pi) -> 0.79788456
# this function is tanh approximation of gelu
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
@
torch
.
jit
.
script
def
bias_gelu
(
bias
,
y
):
x
=
bias
+
y
return
x
*
0.5
*
(
1.0
+
torch
.
tanh
(
0.79788456
*
x
*
(
1
+
0.044715
*
x
*
x
)))
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@
torch
.
jit
.
script
def
bias_gelu_back
(
g
,
bias
,
y
):
x
=
bias
+
y
tanh_out
=
torch
.
tanh
(
0.79788456
*
x
*
(
1
+
0.044715
*
x
*
x
))
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
ff
=
0.5
*
x
*
((
1
-
tanh_out
*
tanh_out
)
*
(
0.79788456
+
0.1070322243
*
x
*
x
))
+
0.5
*
(
1
+
tanh_out
)
return
ff
*
g
class
GeLUFunction
(
torch
.
autograd
.
Function
):
@
staticmethod
# bias is an optional argument
def
forward
(
ctx
,
input
,
bias
):
ctx
.
save_for_backward
(
input
,
bias
)
return
bias_gelu
(
bias
,
input
)
@
staticmethod
def
backward
(
ctx
,
grad_output
):
input
,
bias
=
ctx
.
saved_tensors
tmp
=
bias_gelu_back
(
grad_output
,
bias
,
input
)
return
tmp
,
tmp
bias_gelu_impl
=
GeLUFunction
.
apply
megatron/model/fused_softmax.py
0 → 100644
View file @
51a2e6b0
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
torch
class
ScaledUpperTriangMaskedSoftmax
(
torch
.
autograd
.
Function
)
:
"""
Fused operation which performs following three operations in sequence
1. Scale the tensor.
2. Apply upper triangular mask (typically used in gpt models).
3. Perform softmax.
"""
@
staticmethod
def
forward
(
ctx
,
inputs
,
scale
):
import
scaled_upper_triang_masked_softmax_cuda
scale_t
=
torch
.
tensor
([
scale
])
softmax_results
=
\
scaled_upper_triang_masked_softmax_cuda
.
forward
(
inputs
,
scale_t
[
0
])
ctx
.
save_for_backward
(
softmax_results
,
scale_t
)
return
softmax_results
@
staticmethod
def
backward
(
ctx
,
output_grads
):
import
scaled_upper_triang_masked_softmax_cuda
softmax_results
,
scale_t
=
ctx
.
saved_tensors
input_grads
=
\
scaled_upper_triang_masked_softmax_cuda
.
backward
(
output_grads
,
softmax_results
,
scale_t
[
0
])
return
input_grads
,
None
class
FusedScaleMaskSoftmax
(
torch
.
nn
.
Module
):
"""
fused operation: scaling + mask + softmax
Arguments:
input_in_fp16: flag to indicate if input in fp16 data format.
upper_triang_mask: if true, apply upper triangular masking.
(used in gpt family networks)
mask_func: mask function to be applied.
softmax_in_fp32: if true, softmax in performed at fp32 precision.
scale: scaling factor used in input tensor scaling.
"""
def
__init__
(
self
,
input_in_fp16
,
upper_triang_mask
,
mask_func
,
softmax_in_fp32
,
scale
):
super
(
FusedScaleMaskSoftmax
,
self
).
__init__
()
self
.
input_in_fp16
=
input_in_fp16
self
.
upper_triang_mask
=
upper_triang_mask
self
.
mask_func
=
mask_func
self
.
softmax_in_fp32
=
softmax_in_fp32
self
.
scale
=
scale
assert
self
.
scale
is
None
or
softmax_in_fp32
,
\
'softmax should be in fp32 when scaled'
def
forward
(
self
,
input
,
mask
):
# [b, np, s, s]
data_size
=
input
.
size
()
assert
input
.
dim
()
==
4
# invoke custom kernel for implicit uuper triangular masking
if
self
.
input_in_fp16
and
self
.
upper_triang_mask
and
\
data_size
[
-
1
]
<=
2048
and
input
.
size
()[
2
]
==
input
.
size
()[
3
]:
input
=
input
.
view
(
-
1
,
data_size
[
2
],
data_size
[
3
])
scale
=
self
.
scale
if
self
.
scale
is
not
None
else
1.0
probs
=
ScaledUpperTriangMaskedSoftmax
.
apply
(
input
,
scale
)
probs
=
probs
.
view
(
*
data_size
)
else
:
if
self
.
input_in_fp16
and
self
.
softmax_in_fp32
:
input
=
input
.
float
()
mask_output
=
self
.
mask_func
(
input
,
mask
)
if
self
.
scale
is
not
None
:
mask_output
=
mask_output
*
self
.
scale
probs
=
torch
.
nn
.
Softmax
(
dim
=-
1
)(
mask_output
)
if
self
.
input_in_fp16
and
self
.
softmax_in_fp32
:
probs
=
probs
.
half
()
return
probs
megatron/model/language_model.py
View file @
51a2e6b0
...
...
@@ -22,7 +22,6 @@ from megatron import get_args
from
megatron
import
mpu
from
megatron.module
import
MegatronModule
from
megatron.model.transformer
import
ParallelTransformer
from
megatron.model.utils
import
openai_gelu
,
erf_gelu
from
megatron.model.utils
import
get_linear_layer
from
megatron.model.utils
import
init_method_normal
,
scaled_init_method_normal
...
...
@@ -48,13 +47,6 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
"""Build language model and return along with the key to save."""
args
=
get_args
()
# Use torch gelu unless otherwise forced.
gelu
=
F
.
gelu
if
args
.
openai_gelu
:
gelu
=
openai_gelu
elif
args
.
onnx_safe
:
gelu
=
erf_gelu
if
init_method
is
None
:
init_method
=
init_method_normal
(
args
.
init_method_std
)
...
...
@@ -64,7 +56,6 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
# Language model.
language_model
=
TransformerLanguageModel
(
attention_mask_func
=
attention_mask_func
,
mlp_activation_func
=
gelu
,
init_method
=
init_method
,
output_layer_init_method
=
scaled_init_method
,
num_tokentypes
=
num_tokentypes
,
...
...
@@ -271,7 +262,6 @@ class TransformerLanguageModel(MegatronModule):
def
__init__
(
self
,
attention_mask_func
,
mlp_activation_func
,
init_method
,
output_layer_init_method
,
num_tokentypes
=
0
,
...
...
@@ -295,8 +285,8 @@ class TransformerLanguageModel(MegatronModule):
# Transformer
self
.
transformer
=
ParallelTransformer
(
attention_mask_func
,
mlp_activation_func
,
self
.
init_method
,
output_layer_init_method
)
attention_mask_func
,
self
.
init_method
,
output_layer_init_method
)
self
.
_transformer_key
=
'transformer'
# Pooler
...
...
megatron/model/transformer.py
View file @
51a2e6b0
This diff is collapsed.
Click to expand it.
megatron/mpu/layers.py
View file @
51a2e6b0
...
...
@@ -54,7 +54,7 @@ def _initialize_affine_weight_gpu(weight, init_method,
weight
.
model_parallel
=
True
weight
.
partition_dim
=
partition_dim
weight
.
partition_stride
=
stride
with
get_cuda_rng_tracker
().
fork
():
init_method
(
weight
)
...
...
@@ -186,11 +186,15 @@ class ColumnParallelLinear(torch.nn.Module):
keep_master_weight_for_test: This was added for testing and should be
set to False. It returns the master weights
used for initialization.
skip_bias_add: This was added to enable performance optimations where bias
can be fused with other elementwise operations. we skip
adding bias but instead return it.
"""
def
__init__
(
self
,
input_size
,
output_size
,
bias
=
True
,
gather_output
=
True
,
init_method
=
init
.
xavier_normal_
,
stride
=
1
,
keep_master_weight_for_test
=
False
):
keep_master_weight_for_test
=
False
,
skip_bias_add
=
False
):
super
(
ColumnParallelLinear
,
self
).
__init__
()
# Keep input parameters
...
...
@@ -200,6 +204,7 @@ class ColumnParallelLinear(torch.nn.Module):
# Divide the weight matrix along the last dimension.
world_size
=
get_model_parallel_world_size
()
self
.
output_size_per_partition
=
divide
(
output_size
,
world_size
)
self
.
skip_bias_add
=
skip_bias_add
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
...
...
@@ -245,13 +250,16 @@ class ColumnParallelLinear(torch.nn.Module):
# Set up backprop all-reduce.
input_parallel
=
copy_to_model_parallel_region
(
input_
)
# Matrix multiply.
output_parallel
=
F
.
linear
(
input_parallel
,
self
.
weight
,
self
.
bias
)
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
output_parallel
=
F
.
linear
(
input_parallel
,
self
.
weight
,
bias
)
if
self
.
gather_output
:
# All-gather across the partitions.
output
=
gather_from_model_parallel_region
(
output_parallel
)
else
:
output
=
output_parallel
return
output
output
=
output_parallel
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
return
output
,
output_bias
class
RowParallelLinear
(
torch
.
nn
.
Module
):
...
...
@@ -279,12 +287,16 @@ class RowParallelLinear(torch.nn.Module):
keep_master_weight_for_test: This was added for testing and should be
set to False. It returns the master weights
used for initialization.
skip_bias_add: This was added to enable performance optimations where bias
can be fused with other elementwise operations. we skip
adding bias but instead return it.
"""
def
__init__
(
self
,
input_size
,
output_size
,
bias
=
True
,
input_is_parallel
=
False
,
init_method
=
init
.
xavier_normal_
,
stride
=
1
,
keep_master_weight_for_test
=
False
):
keep_master_weight_for_test
=
False
,
skip_bias_add
=
False
):
super
(
RowParallelLinear
,
self
).
__init__
()
# Keep input parameters
...
...
@@ -294,6 +306,7 @@ class RowParallelLinear(torch.nn.Module):
# Divide the weight matrix along the last dimension.
world_size
=
get_model_parallel_world_size
()
self
.
input_size_per_partition
=
divide
(
input_size
,
world_size
)
self
.
skip_bias_add
=
skip_bias_add
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
...
...
@@ -340,8 +353,11 @@ class RowParallelLinear(torch.nn.Module):
output_parallel
=
F
.
linear
(
input_parallel
,
self
.
weight
)
# All-reduce across all the partitions.
output_
=
reduce_from_model_parallel_region
(
output_parallel
)
if
self
.
bias
is
not
None
:
output
=
output_
+
self
.
bias
if
not
self
.
skip_bias_add
:
output
=
output_
+
self
.
bias
if
self
.
bias
is
not
None
else
output_
output_bias
=
None
else
:
output
=
output_
return
output
output_bias
=
self
.
bias
return
output
,
output_bias
megatron/training.py
View file @
51a2e6b0
...
...
@@ -236,29 +236,35 @@ def backward_step(optimizer, model, loss):
timers
=
get_timers
()
# Backward pass.
timers
(
'backward-backward'
).
start
()
optimizer
.
zero_grad
(
set_grads_to_None
=
True
)
if
args
.
fp16
:
optimizer
.
backward
(
loss
,
update_master_grads
=
False
)
else
:
loss
.
backward
()
timers
(
'backward-backward'
).
stop
()
# All-reduce if needed.
if
args
.
DDP_impl
==
'local'
:
timers
(
'allreduce'
).
start
()
timers
(
'
backward-
allreduce'
).
start
()
model
.
allreduce_params
(
reduce_after
=
False
,
fp32_allreduce
=
args
.
fp32_allreduce
)
timers
(
'allreduce'
).
stop
()
timers
(
'
backward-
allreduce'
).
stop
()
# Update master gradients.
timers
(
'backward-master-grad'
).
start
()
if
args
.
fp16
:
optimizer
.
update_master_grads
()
timers
(
'backward-master-grad'
).
stop
()
# Clipping gradients helps prevent the exploding gradient.
timers
(
'backward-clip-grad'
).
start
()
if
args
.
clip_grad
>
0
:
if
not
args
.
fp16
:
mpu
.
clip_grad_norm
(
model
.
parameters
(),
args
.
clip_grad
)
else
:
optimizer
.
clip_master_grads
(
args
.
clip_grad
)
timers
(
'backward-clip-grad'
).
stop
()
def
train_step
(
forward_step_func
,
data_iterator
,
...
...
@@ -311,7 +317,10 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
timers_to_log
.
append
(
name
)
add_to_logging
(
'forward'
)
add_to_logging
(
'backward'
)
add_to_logging
(
'allreduce'
)
add_to_logging
(
'backward-backward'
)
add_to_logging
(
'backward-allreduce'
)
add_to_logging
(
'backward-master-grad'
)
add_to_logging
(
'backward-clip-grad'
)
add_to_logging
(
'optimizer'
)
add_to_logging
(
'batch generator'
)
...
...
megatron/utils.py
View file @
51a2e6b0
...
...
@@ -46,9 +46,9 @@ def report_memory(name):
torch
.
cuda
.
memory_allocated
()
/
mega_bytes
)
string
+=
' | max allocated: {}'
.
format
(
torch
.
cuda
.
max_memory_allocated
()
/
mega_bytes
)
string
+=
' |
cach
ed: {}'
.
format
(
torch
.
cuda
.
memory_
cach
ed
()
/
mega_bytes
)
string
+=
' | max
cach
ed: {}'
.
format
(
torch
.
cuda
.
max_memory_
cach
ed
()
/
mega_bytes
)
string
+=
' |
reserv
ed: {}'
.
format
(
torch
.
cuda
.
memory_
reserv
ed
()
/
mega_bytes
)
string
+=
' | max
reserv
ed: {}'
.
format
(
torch
.
cuda
.
max_memory_
reserv
ed
()
/
mega_bytes
)
print_rank_0
(
string
)
...
...
pretrain_gpt2.py
View file @
51a2e6b0
...
...
@@ -28,7 +28,6 @@ from megatron.training import pretrain
from
megatron.utils
import
get_ltor_masks_and_position_ids
from
megatron.utils
import
reduce_losses
def
model_provider
():
"""Build the model."""
...
...
@@ -80,10 +79,8 @@ def forward_step(data_iterator, model):
tokens
,
labels
,
loss_mask
,
attention_mask
,
position_ids
=
get_batch
(
data_iterator
)
timers
(
'batch generator'
).
stop
()
# Forward model.
losses
=
model
(
tokens
,
position_ids
,
attention_mask
,
labels
=
labels
)
loss_mask
=
loss_mask
.
view
(
-
1
)
loss
=
torch
.
sum
(
losses
.
view
(
-
1
)
*
loss_mask
)
/
loss_mask
.
sum
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment