OpenDAS / Megatron-LM / Commits

Commit a5acbf53, authored Mar 19, 2021 by Mostofa Patwary
Merge branch 'main' into main_retriver_merge_ict_eval
Parents: 40565390, a6e00d97

Showing 19 changed files with 590 additions and 751 deletions (+590 -751)
megatron/arguments.py                                                +1    -28
megatron/fused_kernels/__init__.py                                   +70   -87
megatron/fused_kernels/layer_norm_cuda.cpp                           +32   -91
megatron/fused_kernels/layer_norm_cuda_kernel.cu                     +33   -33
megatron/fused_kernels/scaled_masked_softmax.cpp                     +9    -6
megatron/fused_kernels/scaled_masked_softmax.h                       +86   -42
megatron/fused_kernels/scaled_masked_softmax_cuda.cu                 +28   -20
megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp        +10   -7
megatron/fused_kernels/scaled_upper_triang_masked_softmax.h          +109  -37
megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu    +25   -16
megatron/fused_kernels/type_shim.h                                   +69   -205
megatron/initialize.py                                               +71   -12
megatron/model/__init__.py                                           +1    -17
megatron/model/bert_model.py                                         +1    -2
megatron/model/fused_layer_norm.py                                   +28   -117
megatron/model/fused_softmax.py                                      +13   -5
megatron/model/transformer.py                                        +3    -13
megatron/optimizer/__init__.py                                       +1    -3
megatron/training.py                                                 +0    -10
megatron/arguments.py

@@ -19,7 +19,6 @@ import argparse
 import os

 import torch
-from megatron import fused_kernels

 def parse_args(extra_args_provider=None, defaults={},
                ignore_unknown_args=False):
@@ -134,8 +133,7 @@ def parse_args(extra_args_provider=None, defaults={},
     if args.bf16:
         assert not args.fp16
         args.params_dtype = torch.bfloat16
-        # No fusion is support for bfloat for now
-        assert not args.masked_softmax_fusion
+        # Jitting fusion is not supported for bfloat for now
         assert not args.bias_gelu_fusion
         assert not args.bias_dropout_fusion
@@ -227,31 +225,6 @@ def parse_args(extra_args_provider=None, defaults={},
         'for distribute-checkpointed-activations to work you '\
         'need to enable checkpoint-activations'

-    # custom kernel constraints check
-    seq_len = args.seq_length
-    attn_batch_size = \
-        (args.num_attention_heads / args.tensor_model_parallel_size) * \
-        args.micro_batch_size
-    # constraints on sequence length and attn_batch_size to enable warp based
-    # optimization and upper triangular optimization (for causal mask)
-    custom_kernel_constraint = seq_len > 16 and seq_len <= 2048 and \
-        seq_len % 4 == 0 and attn_batch_size % 4 == 0
-    if not (args.fp16 and custom_kernel_constraint and
-            args.masked_softmax_fusion):
-        print('WARNING: constraints for invoking optimized'
-              ' fused softmax kernel are not met. We default back to unfused'
-              ' kernel invocations.')
-
-    # Load scaled_masked_softmax_fusion_kernels
-    if args.masked_softmax_fusion:
-        fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
-        fused_kernels.load_scaled_masked_softmax_fusion_kernel()
-
-    # Load mixed precision fused layer norm.
-    if args.fp32_residual_connection:
-        fused_kernels.load_fused_mix_prec_layer_norm_kernel()
-
     _print_args(args)
     return args
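With this hunk the constraint check and the eager per-kernel loading drop out of parse_args(); compilation is now requested through the single fused_kernels.load(args) entry point introduced in the megatron/fused_kernels/__init__.py diff below. A minimal sketch of how a caller would now trigger compilation; the wrapper name is illustrative, since the megatron/initialize.py hunk is collapsed on this page:

    # Illustrative sketch only; the real call site lives in megatron/initialize.py,
    # whose diff is not expanded here. fused_kernels.load(args) is the new API.
    from megatron import fused_kernels

    def compile_fused_kernels(args):
        # JIT-builds the fused softmax and mixed-precision layer-norm extensions
        # into megatron/fused_kernels/build/ on first use.
        fused_kernels.load(args)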
megatron/fused_kernels/__init__.py

@@ -13,114 +13,97 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
 import pathlib
 import subprocess
-import os

 from torch.utils import cpp_extension

 # Setting this param to a list has a problem of generating different
 # compilation commands (with diferent order of architectures) and
 # leading to recompilation of fused kernels. Set it to empty string
 # to avoid recompilation and assign arch flags explicity in
 # extra_cuda_cflags below
 os.environ["TORCH_CUDA_ARCH_LIST"] = ""


-def get_cuda_bare_metal_version(cuda_dir):
-    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
-                                         universal_newlines=True)
-    output = raw_output.split()
-    release_idx = output.index("release") + 1
-    release = output[release_idx].split(".")
-    bare_metal_major = release[0]
-    bare_metal_minor = release[1][0]
-    return raw_output, bare_metal_major, bare_metal_minor
-
-
-def create_build_dir(buildpath):
-    try:
-        os.mkdir(buildpath)
-    except OSError:
-        if not os.path.isdir(buildpath):
-            print(f"Creation of the build directory {buildpath} failed")
-
-
-def load_scaled_upper_triang_masked_softmax_fusion_kernel():
-    # Check, if CUDA 11 is installed for compute capability 8.0
+def load(args):
+    # Check if cuda 11 is installed for compute capability 8.0
     cc_flag = []
-    _, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
+    _, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
     if int(bare_metal_major) >= 11:
         cc_flag.append('-gencode')
         cc_flag.append('arch=compute_80,code=sm_80')

+    # Build path
     srcpath = pathlib.Path(__file__).parent.absolute()
     buildpath = srcpath / 'build'
-    create_build_dir(buildpath)
+    _create_build_dir(buildpath)

-    scaled_upper_triang_masked_softmax_cuda = cpp_extension.load(
-        name='scaled_upper_triang_masked_softmax_cuda',
-        sources=[srcpath / 'scaled_upper_triang_masked_softmax.cpp',
-                 srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu'],
-        build_directory=buildpath,
-        extra_cflags=['-O3',],
-        extra_cuda_cflags=['-O3',
-                           '-gencode', 'arch=compute_70,code=sm_70',
-                           '-U__CUDA_NO_HALF_OPERATORS__',
-                           '-U__CUDA_NO_HALF_CONVERSIONS__',
-                           '--expt-relaxed-constexpr',
-                           '--expt-extended-lambda',
-                           '--use_fast_math'] + cc_flag)
-
-
-def load_scaled_masked_softmax_fusion_kernel():
-    # Check, if CUDA11 is installed for compute capability 8.0
-    cc_flag = []
-    _, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
-    if int(bare_metal_major) >= 11:
-        cc_flag.append('-gencode')
-        cc_flag.append('arch=compute_80,code=sm_80')
-
-    srcpath = pathlib.Path(__file__).parent.absolute()
-    buildpath = srcpath / 'build'
-    create_build_dir(buildpath)
-
-    scaled_upper_triang_masked_softmax_cuda = cpp_extension.load(
-        name='scaled_masked_softmax_cuda',
-        sources=[srcpath / 'scaled_masked_softmax.cpp',
-                 srcpath / 'scaled_masked_softmax_cuda.cu'],
-        build_directory=buildpath,
-        extra_cflags=['-O3',],
-        extra_cuda_cflags=['-O3',
-                           '-gencode', 'arch=compute_70,code=sm_70',
-                           '-U__CUDA_NO_HALF_OPERATORS__',
-                           '-U__CUDA_NO_HALF_CONVERSIONS__',
-                           '--expt-relaxed-constexpr',
-                           '--expt-extended-lambda',
-                           '--use_fast_math'] + cc_flag)
-
-
-def load_fused_mix_prec_layer_norm_kernel():
-    # Check, if CUDA11 is installed for compute capability 8.0
-    cc_flag = []
-    _, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
-    if int(bare_metal_major) >= 11:
-        cc_flag.append('-gencode')
-        cc_flag.append('arch=compute_80,code=sm_80')
-
-    srcpath = pathlib.Path(__file__).parent.absolute()
-    buildpath = srcpath / 'build'
-    create_build_dir(buildpath)
-
-    fused_mix_prec_layer_norm_cuda = cpp_extension.load(
-        name='fused_mix_prec_layer_norm_cuda',
-        sources=[srcpath / 'layer_norm_cuda.cpp',
-                 srcpath / 'layer_norm_cuda_kernel.cu'],
-        build_directory=buildpath,
-        extra_cflags=['-O3'],
-        extra_cuda_cflags=['-O3',
-                           '-gencode', 'arch=compute_70,code=sm_70',
-                           '-maxrregcount=50',
-                           '--use_fast_math'] + cc_flag)
+    # Helper function to build the kernels.
+    def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
+        return cpp_extension.load(
+            name=name,
+            sources=sources,
+            build_directory=buildpath,
+            extra_cflags=['-O3',],
+            extra_cuda_cflags=['-O3',
+                               '-gencode', 'arch=compute_70,code=sm_70',
+                               '--use_fast_math'] + extra_cuda_flags + cc_flag,
+            verbose=(args.rank == 0))
+
+    # ==============
+    # Fused softmax.
+    # ==============
+
+    if args.masked_softmax_fusion:
+        extra_cuda_flags = ['-U__CUDA_NO_HALF_OPERATORS__',
+                            '-U__CUDA_NO_HALF_CONVERSIONS__',
+                            '--expt-relaxed-constexpr',
+                            '--expt-extended-lambda']
+
+        # Upper triangular softmax.
+        sources = [srcpath / 'scaled_upper_triang_masked_softmax.cpp',
+                   srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu']
+        scaled_upper_triang_masked_softmax_cuda = _cpp_extention_load_helper(
+            "scaled_upper_triang_masked_softmax_cuda",
+            sources, extra_cuda_flags)
+
+        # Masked softmax.
+        sources = [srcpath / 'scaled_masked_softmax.cpp',
+                   srcpath / 'scaled_masked_softmax_cuda.cu']
+        scaled_masked_softmax_cuda = _cpp_extention_load_helper(
+            "scaled_masked_softmax_cuda", sources, extra_cuda_flags)
+
+    # =================================
+    # Mixed precision fused layer norm.
+    # =================================
+
+    extra_cuda_flags = ['-maxrregcount=50']
+    sources = [srcpath / 'layer_norm_cuda.cpp',
+               srcpath / 'layer_norm_cuda_kernel.cu']
+    fused_mix_prec_layer_norm_cuda = _cpp_extention_load_helper(
+        "fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags)
+
+
+def _get_cuda_bare_metal_version(cuda_dir):
+    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
+                                         universal_newlines=True)
+    output = raw_output.split()
+    release_idx = output.index("release") + 1
+    release = output[release_idx].split(".")
+    bare_metal_major = release[0]
+    bare_metal_minor = release[1][0]
+    return raw_output, bare_metal_major, bare_metal_minor
+
+
+def _create_build_dir(buildpath):
+    try:
+        os.mkdir(buildpath)
+    except OSError:
+        if not os.path.isdir(buildpath):
+            print(f"Creation of the build directory {buildpath} failed")
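Once cpp_extension.load has built an extension under a given name, the compiled module can be imported by that name, which is how the Python-side wrappers pick the kernels up after load(args) has run. A hedged usage sketch; the tensor shapes and the wrapper function are assumptions, not part of this diff:

    # Sketch, assuming fused_kernels.load(args) has already been called once.
    # The extension exposes the pybind11 functions bound in scaled_masked_softmax.cpp.
    import torch

    def fused_scaled_masked_softmax(inputs, mask, scale):
        # inputs: [b, np, sq, sk] in fp16/bf16; mask: [b, 1, sq, sk] uint8 (1 = masked).
        import scaled_masked_softmax_cuda
        return scaled_masked_softmax_cuda.forward(inputs, mask, scale)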
megatron/fused_kernels/layer_norm_cuda.cpp

@@ -24,16 +24,12 @@
 #include "compat.h"

 namespace {
 void compute_n1_n2(
     at::Tensor input,
-#ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-#else
-    at::IntList normalized_shape,
-#endif
     int& n1,
     int& n2)
 {
     int idiff = input.ndimension() - normalized_shape.size();
     n2 = 1;
     for (int i = 0;  i < (int)normalized_shape.size();  ++i) {
@@ -47,11 +43,7 @@ void compute_n1_n2(
 }

 void check_args(
-#ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-#else
-    at::IntList normalized_shape,
-#endif
     at::Tensor gamma,
     at::Tensor beta
     )
@@ -62,11 +54,7 @@ void check_args(
 void check_args(
     at::Tensor input,
-#ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-#else
-    at::IntList normalized_shape,
-#endif
     int& n1,
     int& n2
     )
@@ -102,11 +90,7 @@ void check_args(
 void check_args(
     at::Tensor input,
-#ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-#else
-    at::IntList normalized_shape,
-#endif
     at::Tensor gamma,
     at::Tensor beta,
     int& n1,
@@ -125,60 +109,42 @@ void cuda_layer_norm(
     at::Tensor* input,
     int n1,
     int n2,
-#ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-#else
-    at::IntList normalized_shape,
-#endif
     at::Tensor* gamma,
     at::Tensor* beta,
     double epsilon);

-#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
 #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
 #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

-std::vector<at::Tensor> layer_norm(
-    at::Tensor input,
-    at::IntArrayRef normalized_shape,
-    double epsilon) {
-  CHECK_INPUT(input);
-  int n1, n2;
-  check_args(input, normalized_shape, n1, n2);
-  at::Tensor output = at::empty_like(input);
-  at::Tensor mean = at::empty({n1}, input.options().dtype(
-      input.scalar_type() == at::ScalarType::Half ? at::ScalarType::Float : input.scalar_type()));
-  at::Tensor invvar = at::empty_like(mean);
-  cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2,
-                  normalized_shape, NULL, NULL, epsilon);
-  return {output, mean, invvar};
-}
-
 std::vector<at::Tensor> layer_norm_affine(
     at::Tensor input,
-    #ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-    #else
-    at::IntList normalized_shape,
-    #endif
     at::Tensor gamma,
     at::Tensor beta,
     double epsilon) {
   CHECK_INPUT(input);
   CHECK_INPUT(gamma);
   CHECK_INPUT(beta);
   int n1, n2;
   check_args(input, normalized_shape, gamma, beta, n1, n2);
-  at::Tensor output = at::empty_like(input, input.options().dtype(at::ScalarType::Half));
-  at::Tensor mean = at::empty({n1}, input.options().dtype(
-      input.scalar_type() == at::ScalarType::Half ? at::ScalarType::Float : input.scalar_type()));
+  at::Tensor output = at::empty_like(input, gamma.options().dtype(gamma.scalar_type()));
+  at::Tensor mean = at::empty({n1}, input.options().dtype(at::ScalarType::Float));
   at::Tensor invvar = at::empty_like(mean);
   cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2,
                   normalized_shape, &gamma, &beta, epsilon);
   return {output, mean, invvar};
 }

 void cuda_layer_norm_gradient(
     at::Tensor* dout,
     at::Tensor* mean,
@@ -186,11 +152,7 @@ void cuda_layer_norm_gradient(
     at::Tensor* input,
     int n1,
     int n2,
-#ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-#else
-    at::IntList normalized_shape,
-#endif
     at::Tensor* gamma,
     at::Tensor* beta,
     double epsilon,
@@ -199,62 +161,41 @@ void cuda_layer_norm_gradient(
     at::Tensor* grad_beta);

-at::Tensor layer_norm_gradient(
-    at::Tensor dout,
-    at::Tensor mean,
-    at::Tensor invvar,
-    at::Tensor input,
-    at::IntArrayRef normalized_shape,
-    double epsilon) {
-  CHECK_INPUT(dout);
-  CHECK_INPUT(mean);
-  CHECK_INPUT(invvar);
-  CHECK_INPUT(input);
-  int n1, n2;
-  check_args(input, normalized_shape, n1, n2);
-  at::Tensor grad_input = at::empty_like(input);
-  cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2,
-                           normalized_shape, NULL, NULL, epsilon,
-                           &grad_input, NULL, NULL);
-  return grad_input;
-}
-
 std::vector<at::Tensor> layer_norm_gradient_affine(
     at::Tensor dout,
     at::Tensor mean,
     at::Tensor invvar,
     at::Tensor input,
-    #ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-    #else
-    at::IntList normalized_shape,
-    #endif
     at::Tensor gamma,
     at::Tensor beta,
     double epsilon) {
   CHECK_INPUT(dout);
   CHECK_INPUT(mean);
   CHECK_INPUT(invvar);
   CHECK_INPUT(input);
   CHECK_INPUT(gamma);
   CHECK_INPUT(beta);
   int n1, n2;
   check_args(input, normalized_shape, gamma, beta, n1, n2);
   at::Tensor grad_input = at::empty_like(input);
   at::Tensor grad_gamma = at::empty_like(gamma);
   at::Tensor grad_beta = at::empty_like(beta);
   cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2,
                            normalized_shape, &gamma, &beta, epsilon,
                            &grad_input, &grad_gamma, &grad_beta);
   return {grad_input, grad_gamma, grad_beta};
 }

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)");
-  m.def("forward", &layer_norm, "LayerNorm forward (CUDA)");
-  m.def("backward_affine", &layer_norm_gradient_affine, "LayerNorm backward (CUDA)");
-  m.def("backward", &layer_norm_gradient, "LayerNorm backward (CUDA)");
+  m.def("forward_affine", &layer_norm_affine,
+        "LayerNorm forward (CUDA)");
+  m.def("backward_affine", &layer_norm_gradient_affine,
+        "LayerNorm backward (CUDA)");
 }
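After this change forward_affine keeps the running statistics in fp32 while the normalized output takes gamma's dtype (previously it was hard-coded to Half). An eager-mode reference of that contract, for illustration only:

    # Pure-PyTorch stand-in for forward_affine's semantics; not the CUDA path.
    import torch

    def layer_norm_affine_reference(x, normalized_shape, gamma, beta, eps):
        n2 = 1
        for d in normalized_shape:
            n2 *= d
        xf = x.float().view(-1, n2)                                  # accumulate in fp32
        mean = xf.mean(dim=1)                                        # fp32, like the 'mean' buffer
        invvar = torch.rsqrt(xf.var(dim=1, unbiased=False) + eps)    # fp32 'invvar'
        y = (xf - mean[:, None]) * invvar[:, None]
        out = y.view_as(x).to(gamma.dtype) * gamma + beta            # output in gamma's dtype
        return out, mean, invvar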
megatron/fused_kernels/layer_norm_cuda_kernel.cu

@@ -285,15 +285,6 @@ struct SharedMemory <float>
     }
 };

-template <>
-struct SharedMemory <double>
-{
-    __device__ double *getPointer()
-    {
-        extern __shared__ double s_double[];
-        return s_double;
-    }
-};
 }

 template<typename T, typename U, typename V> __global__
@@ -656,6 +647,9 @@ void cuComputeGradInput(
   }
 }

 template<typename T, typename U, typename V>
 void HostApplyLayerNorm(
     V* output,
@@ -671,7 +665,8 @@ void HostApplyLayerNorm(
 {
     auto stream = at::cuda::getCurrentCUDAStream().stream();
     const dim3 threads(32,4,1);
     const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
     const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
     int nshared =
         threads.y > 1 ?
@@ -687,6 +682,7 @@ void HostApplyLayerNorm(
           gamma, beta);
 }

 void cuda_layer_norm(
     at::Tensor* output,
     at::Tensor* mean,
@@ -704,21 +700,21 @@ void cuda_layer_norm(
     double epsilon)
 {
     using namespace at;
-    DISPATCH_DOUBLE_FLOAT_AND_HALF(input->scalar_type(), 0, "layer_norm_cuda_kernel",
-        using accscalar_t = at::acc_type<scalar_t_0, true>;
-        using output_t = at::Half;
+    DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
+        input->scalar_type(), output->scalar_type(), "cuda_layer_norm_kernel",
         HostApplyLayerNorm(
-            output->DATA_PTR<output_t>(),
-            mean->DATA_PTR<accscalar_t>(),
-            invvar->DATA_PTR<accscalar_t>(),
-            input->DATA_PTR<scalar_t_0>(),
+            output->DATA_PTR<scalar_t_out>(),
+            mean->DATA_PTR<float>(),
+            invvar->DATA_PTR<float>(),
+            input->DATA_PTR<scalar_t_in>(),
             n1, n2,
             epsilon,
-            gamma != NULL ? gamma->DATA_PTR<output_t>() : NULL,
-            beta != NULL ? beta->DATA_PTR<output_t>() : NULL);
+            gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
+            beta != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL);
       )
 }

 template<typename T, typename U, typename V>
 void HostLayerNormGradient(
     const V* dout,
@@ -742,10 +738,12 @@ void HostLayerNormGradient(
     const int part_size = 16;
     const dim3 threads2(32,4,1);
     const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1);
     const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
     const int nshared2_b = threads2.x * threads2.y * sizeof(U);
     const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
-    at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype(
-        input->scalar_type() == at::ScalarType::Half ? at::ScalarType::Float : input->scalar_type()));
+    at::Tensor part_grad_gamma = at::empty(
+        {part_size,n2}, input->options().dtype(at::ScalarType::Float));
     at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
     cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
                     dout,
@@ -770,7 +768,8 @@ void HostLayerNormGradient(
     }

     // compute grad_input
     const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
     const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
     const dim3 threads1(32,4,1);
     int nshared =
@@ -788,6 +787,7 @@ void HostLayerNormGradient(
             grad_input);
 }

 void cuda_layer_norm_gradient(
     at::Tensor* dout,
     at::Tensor* mean,
@@ -808,22 +808,22 @@ void cuda_layer_norm_gradient(
     at::Tensor* grad_beta)
 {
     using namespace at;
-    DISPATCH_FLOAT_AND_HALF(input->scalar_type(), 0, "cuComputeGradInput",
-        using accscalar_t = at::acc_type<scalar_t_0, true>;
-        using output_t = at::Half;
+    DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
+        input->scalar_type(), gamma->scalar_type(),
+        "cuda_layer_norm_gradient_kernel",
         HostLayerNormGradient(
-            dout->DATA_PTR<output_t>(),
-            mean->DATA_PTR<accscalar_t>(),
-            invvar->DATA_PTR<accscalar_t>(),
+            dout->DATA_PTR<scalar_t_out>(),
+            mean->DATA_PTR<float>(),
+            invvar->DATA_PTR<float>(),
             input,
             n1, n2,
             // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
             // if gamma Tensor is NULL on input.
-            gamma != NULL ? gamma->DATA_PTR<output_t>() : NULL,
-            gamma != NULL ? beta->DATA_PTR<output_t>() : NULL,
+            gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
+            gamma != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL,
             epsilon,
-            grad_input->DATA_PTR<scalar_t_0>(),
-            gamma != NULL ? grad_gamma->DATA_PTR<output_t>() : NULL,
-            gamma != NULL ? grad_beta->DATA_PTR<output_t>() : NULL);
+            grad_input->DATA_PTR<scalar_t_in>(),
+            gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL,
+            gamma != NULL ? grad_beta->DATA_PTR<scalar_t_out>() : NULL);
       )
 }
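The backward dispatch now instantiates HostLayerNormGradient per (input dtype, gamma dtype) pair with float32 mean/invvar. The arithmetic it feeds into cuComputeGradInput is the usual layer-norm input gradient; a hedged fp32 reference for a 2D [n1, n2] view:

    # Illustrative reference for the grad_input term only (fp32 math, 2D view).
    import torch

    def layer_norm_grad_input_reference(dout, x, mean, invvar, gamma):
        xhat = (x.float() - mean[:, None]) * invvar[:, None]
        g = dout.float() * gamma.float()                  # dL/dy scaled by gamma
        term_mean = g.mean(dim=-1, keepdim=True)
        term_proj = xhat * (g * xhat).mean(dim=-1, keepdim=True)
        return (invvar[:, None] * (g - term_mean - term_proj)).to(x.dtype)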
megatron/fused_kernels/scaled_masked_softmax.cpp

@@ -37,8 +37,9 @@ torch::Tensor fwd(
     torch::Tensor const& mask,
     float scale_factor) {
   AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
-  AT_ASSERTM(input.scalar_type() == at::ScalarType::Half,
-             "Only HALF is supported");
+  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
+             (input.scalar_type() == at::ScalarType::BFloat16),
+             "Only fp16 and bf16 are supported");
   AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");

   return fwd_cuda(input, mask, scale_factor);
@@ -52,10 +53,12 @@ torch::Tensor bwd(
   AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor");
   AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor");

-  AT_ASSERTM(output_grads.scalar_type() == at::ScalarType::Half,
-             "Only HALF is supported");
-  AT_ASSERTM(softmax_results.scalar_type() == at::ScalarType::Half,
-             "Only HALF is supported");
+  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
+             (output_grads.scalar_type() == at::ScalarType::BFloat16),
+             "Only fp16 and bf16 are supported");
+  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
+             (softmax_results.scalar_type() == at::ScalarType::BFloat16),
+             "Only fp16 and bf16 are supported");

   return bwd_cuda(output_grads, softmax_results, scale_factor);
 }
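The binding now accepts bf16 as well as fp16; the numerical contract of fwd is unchanged. A hedged eager-mode reference (the -10000.0 fill value mirrors the kernel in scaled_masked_softmax.h):

    # Reference only; the fused kernel does this per warp with vectorized loads.
    import torch

    def scaled_masked_softmax_reference(inputs, mask, scale_factor):
        assert inputs.dim() == 4 and mask.dim() == 4
        assert inputs.dtype in (torch.float16, torch.bfloat16)
        scores = inputs.float() * scale_factor
        scores = scores.masked_fill(mask.bool(), -10000.0)   # mask == 1 means "masked out"
        return torch.softmax(scores, dim=-1).to(inputs.dtype)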
megatron/fused_kernels/scaled_masked_softmax.h

@@ -26,6 +26,27 @@
 namespace {

+template <typename Datatype, int ELEMENTS_PER_LDG>
+__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
+
+template <>
+__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }
+
+template <>
+__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }
+
+template <>
+__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }
+
+template <>
+__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }
+
+template <>
+__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }
+
+template <>
+__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) { *((half2*) dst) = *((half2*) src); }
+
 int log2_ceil(int value) {
     int log2_value = 0;
     while ((1 << log2_value) < value) ++log2_value;
@@ -90,13 +111,14 @@ __global__ void scaled_masked_softmax_warp_forward(
     constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
     constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
     constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
+    constexpr int ELEMENTS_PER_LDG_STG = 4;

     // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
     // gridDim/blockIdx = (seq_len, attn_heads, batches)
     int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) + threadIdx.y) * WARP_BATCH;
     int pad_first_batch = 0;
     if (pad_batches != 1) { // bert style
         pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH;
     } else { // gpt2 style
         pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
     }
@@ -110,29 +132,40 @@ __global__ void scaled_masked_softmax_warp_forward(
     // there might be multiple batches per warp. compute the index within the batch
     int local_idx = threadIdx.x;

-    src += first_batch * element_count + local_idx;
-    dst += first_batch * element_count + local_idx;
-    mask += pad_first_batch * element_count + local_idx;
+    src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
+    dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
+    mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;

     // load data from global memory
     acc_t elements[WARP_BATCH][WARP_ITERATIONS];
+    input_t temp_data[ELEMENTS_PER_LDG_STG];
+    uint8_t temp_mask[ELEMENTS_PER_LDG_STG];
     #pragma unroll
     for (int i = 0;  i < WARP_BATCH;  ++i) {
         int batch_element_count = (i >= local_batches) ? 0 : element_count;

         #pragma unroll
-        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
-            int element_index = local_idx + it * WARP_SIZE;
-            int itr_idx = i*element_count+it*WARP_SIZE;
+        for (int it = 0;  it < WARP_ITERATIONS;  it += ELEMENTS_PER_LDG_STG) {
+            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
             if (element_index < batch_element_count) {
-                if (mask[itr_idx] != 1) {
-                    elements[i][it] = (acc_t)src[itr_idx] * scale;
-                } else {
-                    elements[i][it] = -10000.0;
-                }
+                int itr_idx = i*element_count+it*WARP_SIZE;
+                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
+                copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(temp_mask, mask + itr_idx);
+
+                #pragma unroll
+                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                    if (temp_mask[element] != 1) {
+                        elements[i][it + element] = (acc_t)temp_data[element] * scale;
+                    } else {
+                        elements[i][it + element] = -10000.0;
+                    }
+                }
             } else {
-                elements[i][it] = -std::numeric_limits<acc_t>::infinity();
+                #pragma unroll
+                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                    elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
+                }
             }
         }
     }
@@ -161,15 +194,20 @@ __global__ void scaled_masked_softmax_warp_forward(
     warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

     // store result
+    output_t out[ELEMENTS_PER_LDG_STG];
     #pragma unroll
     for (int i = 0;  i < WARP_BATCH;  ++i) {
         if (i >= local_batches)
             break;
         #pragma unroll
-        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
-            int element_index = local_idx + it * WARP_SIZE;
+        for (int it = 0;  it < WARP_ITERATIONS;  it += ELEMENTS_PER_LDG_STG) {
+            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
             if (element_index < element_count) {
-                dst[i*element_count+it*WARP_SIZE] = (output_t)(elements[i][it] / sum[i]);
+                #pragma unroll
+                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                    out[element] = elements[i][it + element] / sum[i];
+                }
+                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
             } else {
                 break;
             }
@@ -192,6 +230,7 @@ __global__ void scaled_masked_softmax_warp_backward(
     constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
     constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
     constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
+    constexpr int ELEMENTS_PER_LDG_STG = 4;

     // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
     // gridDim/blockIdx = (seq_len, attn_heads, batches)
@@ -207,36 +246,36 @@ __global__ void scaled_masked_softmax_warp_backward(
     int local_idx = threadIdx.x;

     // the first element to process by the current thread
-    int thread_offset = first_batch * element_count + local_idx;
+    int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
     grad += thread_offset;
     output += thread_offset;
     gradInput += thread_offset;

     // load data from global memory
     acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
-    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
+    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
+    input_t temp_grad[ELEMENTS_PER_LDG_STG];
+    input_t temp_output[ELEMENTS_PER_LDG_STG];
     #pragma unroll
     for (int i = 0;  i < WARP_BATCH;  ++i) {
         int batch_element_count = (i >= local_batches) ? 0 : element_count;

         #pragma unroll
-        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
-            int element_index = local_idx + it * WARP_SIZE;
+        for (int it = 0;  it < WARP_ITERATIONS;  it += ELEMENTS_PER_LDG_STG) {
+            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
             if (element_index < batch_element_count) {
-                output_reg[i][it] = output[i*element_count+it*WARP_SIZE];
-            } else {
-                output_reg[i][it] = acc_t(0);
-            }
-        }
-
-        #pragma unroll
-        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
-            int element_index = local_idx + it * WARP_SIZE;
-            if (element_index < batch_element_count) {
-                grad_reg[i][it] = (acc_t)grad[i*element_count+it*WARP_SIZE] * output_reg[i][it];
-            } else {
-                grad_reg[i][it] = acc_t(0);
+                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count + it * WARP_SIZE);
+                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count + it * WARP_SIZE);
+
+                #pragma unroll
+                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                    output_reg[i][it + element] = (acc_t)temp_output[element];
+                }
+                #pragma unroll
+                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                    grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
+                }
             }
         }
     }
@@ -257,11 +296,16 @@ __global__ void scaled_masked_softmax_warp_backward(
         if (i >= local_batches)
             break;
         #pragma unroll
-        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
-            int element_index = local_idx + it * WARP_SIZE;
+        for (int it = 0;  it < WARP_ITERATIONS;  it += ELEMENTS_PER_LDG_STG) {
+            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
             if (element_index < element_count) {
                 // compute gradients
-                gradInput[i*element_count+it*WARP_SIZE] = (output_t)(scale * (grad_reg[i][it] - output_reg[i][it] * sum[i]));
+                output_t out[ELEMENTS_PER_LDG_STG];
+                #pragma unroll
+                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                    out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
+                }
+                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count + it * WARP_SIZE, out);
             }
         }
     }
@@ -299,8 +343,8 @@ void dispatch_scaled_masked_softmax_forward(
     constexpr int threads_per_block = 128;
     int warps_per_block = (threads_per_block / warp_size);
     int batches_per_block = warps_per_block * batches_per_warp;
     TORCH_INTERNAL_ASSERT(query_seq_len % batches_per_block == 0);
     dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches);
     dim3 threads(warp_size, warps_per_block, 1);
     // Launch code would be more elegant if C++ supported FOR CONSTEXPR
@@ -388,7 +432,7 @@ void dispatch_scaled_masked_softmax_backward(
     constexpr int threads_per_block = 128;
     int warps_per_block = (threads_per_block / warp_size);
     int batches_per_block = warps_per_block * batches_per_warp;
     int blocks = batch_count/batches_per_block;
     dim3 threads(warp_size, warps_per_block, 1);
     // Launch code would be more elegant if C++ supported FOR CONSTEXPR
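The vectorization (ELEMENTS_PER_LDG_STG = 4 with copy_vector) only changes how elements move between global memory and registers; the per-row math is untouched. For the backward kernel that math is grad = scale * (g*y - y * sum(g*y)), illustrated here in eager PyTorch:

    # Reference for scaled_masked_softmax_warp_backward's arithmetic (fp32 accumulation).
    import torch

    def scaled_masked_softmax_backward_reference(output_grads, softmax_results, scale_factor):
        g = output_grads.float()
        y = softmax_results.float()
        gy = g * y
        grad_inputs = scale_factor * (gy - y * gy.sum(dim=-1, keepdim=True))
        return grad_inputs.to(output_grads.dtype)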
megatron/fused_kernels/scaled_masked_softmax_cuda.cu

@@ -19,10 +19,10 @@
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
 #include <cuda_profiler_api.h>
-#include "THC/THC.h"
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
 #include "scaled_masked_softmax.h"
+#include "type_shim.h"

 namespace multihead_attn {
 namespace fused_softmax {
@@ -56,16 +56,20 @@ torch::Tensor fwd_cuda(
   void* mask_ptr = static_cast<void*>(mask.data_ptr());
   void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());

-  dispatch_scaled_masked_softmax_forward<half, half, float>(
-      reinterpret_cast<half*>(softmax_results_ptr),
-      reinterpret_cast<const half*>(input_ptr),
-      reinterpret_cast<const uint8_t*>(mask_ptr),
-      scale_factor,
-      query_seq_len,
-      key_seq_len,
-      batches,
-      attn_heads,
-      pad_batches);
+  DISPATCH_HALF_AND_BFLOAT(
+      input.scalar_type(),
+      "dispatch_scaled_masked_softmax_forward",
+      dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
+          reinterpret_cast<scalar_t*>(softmax_results_ptr),
+          reinterpret_cast<const scalar_t*>(input_ptr),
+          reinterpret_cast<const uint8_t*>(mask_ptr),
+          scale_factor,
+          query_seq_len,
+          key_seq_len,
+          batches,
+          attn_heads,
+          pad_batches);
+  );
   return softmax_results;
 }
@@ -86,15 +90,19 @@ torch::Tensor bwd_cuda(
   void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

   //Softmax Grad
-  dispatch_scaled_masked_softmax_backward<half, half, float>(
-      reinterpret_cast<half*>(output_grads_ptr),
-      reinterpret_cast<half*>(output_grads_ptr),
-      reinterpret_cast<half const*>(softmax_results.data_ptr()),
-      scale_factor,
-      query_seq_len,
-      key_seq_len,
-      batches,
-      attn_heads);
+  DISPATCH_HALF_AND_BFLOAT(
+      output_grads_.scalar_type(),
+      "dispatch_scaled_masked_softmax_backward",
+      dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
+          reinterpret_cast<scalar_t*>(output_grads_ptr),
+          reinterpret_cast<scalar_t*>(output_grads_ptr),
+          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
+          scale_factor,
+          query_seq_len,
+          key_seq_len,
+          batches,
+          attn_heads);
+  );

   //backward pass is completely in-place
   return output_grads;
megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp

@@ -33,8 +33,9 @@ torch::Tensor bwd_cuda(
 torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
   AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
-  AT_ASSERTM(input.scalar_type() == at::ScalarType::Half,
-             "Only HALF is supported");
+  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
+             (input.scalar_type() == at::ScalarType::BFloat16),
+             "Only fp16 and bf16 are supported");

   return fwd_cuda(input, scale_factor);
 }
@@ -47,10 +48,12 @@ torch::Tensor bwd(
   AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
   AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");

-  AT_ASSERTM(output_grads.scalar_type() == at::ScalarType::Half,
-             "Only HALF is supported");
-  AT_ASSERTM(softmax_results.scalar_type() == at::ScalarType::Half,
-             "Only HALF is supported");
+  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
+             (output_grads.scalar_type() == at::ScalarType::BFloat16),
+             "Only fp16 and bf16 are supported");
+  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
+             (softmax_results.scalar_type() == at::ScalarType::BFloat16),
+             "Only fp16 and bf16 are supported");

   return bwd_cuda(output_grads, softmax_results, scale_factor);
 }
@@ -61,7 +64,7 @@ torch::Tensor bwd(
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("forward",
         &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
        "Self Multihead Attention scaled, time masked softmax -- Forward.");
   m.def("backward",
         &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
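The upper-triangular variant applies the same scale-then-softmax, but the mask is implicit: position (q, k) is dropped whenever k > q (causal masking). A hedged eager-mode reference for the 3D layout the binding checks for; the [attn_batches, sq, sk] interpretation is an assumption:

    # Reference only; the fused kernel realizes the causal mask via local_seq per row.
    import torch

    def scaled_upper_triang_masked_softmax_reference(inputs, scale_factor):
        assert inputs.dim() == 3
        assert inputs.dtype in (torch.float16, torch.bfloat16)
        sq, sk = inputs.shape[-2], inputs.shape[-1]
        causal = torch.triu(torch.ones(sq, sk, dtype=torch.bool, device=inputs.device),
                            diagonal=1)
        scores = inputs.float() * scale_factor
        scores = scores.masked_fill(causal, float("-inf"))
        return torch.softmax(scores, dim=-1).to(inputs.dtype)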
megatron/fused_kernels/scaled_upper_triang_masked_softmax.h

@@ -21,11 +21,47 @@
 #include <cfloat>
 #include <limits>
 #include <stdint.h>
-#include <cuda_fp16.h>
 #include <c10/macros/Macros.h>

 namespace {

+template <typename Datatype, int ELEMENTS_PER_LDG>
+__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
+
+template <>
+__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }
+
+template <>
+__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }
+
+template <>
+__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }
+
+template <>
+__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }
+
+template <>
+__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }
+
+template <>
+__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) { *((half2*) dst) = *((half2*) src); }
+
+template <typename Datatype, int ELEMENTS_PER_LDG>
+__device__ __inline__ void copy_zero_vector(Datatype *dst);
+
+template <>
+__device__ __inline__ void copy_zero_vector<c10::BFloat16, 1>(c10::BFloat16 *dst) { *dst = 0.0; }
+
+template <>
+__device__ __inline__ void copy_zero_vector<c10::BFloat16, 4>(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }
+
+template <>
+__device__ __inline__ void copy_zero_vector<c10::Half, 1>(c10::Half *dst) { *dst = 0.0; }
+
+template <>
+__device__ __inline__ void copy_zero_vector<c10::Half, 4>(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }
+
 int log2_ceil(int value) {
     int log2_value = 0;
     while ((1 << log2_value) < value) ++log2_value;
@@ -73,7 +109,7 @@ __device__ __forceinline__ void warp_reduce(acc_t* sum) {
 * Extended softmax (from native aten pytorch) with following additional features
 * 1) input scaling
 * 2) Implicit time (diagonal masking)
 */
 template <typename input_t, typename output_t, typename acc_t, int log2_elements>
 __global__ void scaled_upper_triang_masked_softmax_warp_forward(
     output_t *dst,
@@ -89,10 +125,11 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
     constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
     constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
     constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
+    constexpr int ELEMENTS_PER_LDG_STG = 4;

     int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
     int local_seq = blockIdx.x + 1;
-    int warp_iteration_limit = (local_seq + WARP_SIZE - 1)/WARP_SIZE;
+    int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1)/ WARP_SIZE;

     // micro_batch_size might not be a multiple of WARP_BATCH. Check how
     // many batches have to computed within this WARP.
@@ -103,22 +140,36 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
     // there might be multiple batches per warp. compute the index within the batch
     int local_idx = threadIdx.x;

-    src += first_batch * stride + local_idx;
-    dst += first_batch * stride + local_idx;
+    src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
+    dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;

     // load data from global memory
     acc_t elements[WARP_BATCH][WARP_ITERATIONS];
+    input_t temp_data[ELEMENTS_PER_LDG_STG];
     #pragma unroll
     for (int i = 0;  i < WARP_BATCH;  ++i) {
         int batch_element_count = (i >= local_batches) ? 0 : local_seq;

         #pragma unroll
-        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
-            int element_index = local_idx + it * WARP_SIZE;
+        for (int it = 0;  it < WARP_ITERATIONS;  it += ELEMENTS_PER_LDG_STG) {
+            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;

             if (element_index < batch_element_count) {
-                elements[i][it] = (acc_t)src[i*element_count*stride+it*WARP_SIZE] * scale;
+                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + i*element_count*stride + it*WARP_SIZE);
+
+                #pragma unroll
+                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                    if ((element_index + element) < batch_element_count) {
+                        elements[i][it+element] = (acc_t)temp_data[element] * scale;
+                    } else {
+                        elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
+                    }
+                }
             } else {
-                elements[i][it] = -std::numeric_limits<acc_t>::infinity();
+                #pragma unroll
+                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                    elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
+                }
             }
         }
     }
@@ -140,26 +191,37 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
     for (int i = 0;  i < WARP_BATCH;  ++i) {
         #pragma unroll
         for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
             if (it < warp_iteration_limit) {
                 elements[i][it] = std::exp((elements[i][it] - max_value[i]));
                 sum[i] += elements[i][it];
             }
         }
     }
     warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

     // store result
+    output_t out[ELEMENTS_PER_LDG_STG];
     #pragma unroll
     for (int i = 0;  i < WARP_BATCH;  ++i) {
         if (i >= local_batches)
             break;
         #pragma unroll
-        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
-            int element_index = local_idx + it * WARP_SIZE;
+        for (int it = 0;  it < WARP_ITERATIONS;  it += ELEMENTS_PER_LDG_STG) {
+            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
             if (element_index < local_seq) {
-                dst[i*element_count*stride+it*WARP_SIZE] = (output_t)(elements[i][it] / sum[i]);
+                #pragma unroll
+                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                    if (element_index + element < local_seq) {
+                        out[element] = elements[i][it + element] / sum[i];
+                    } else {
+                        out[element] = 0;
+                    }
+                }
+                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE, out);
             } else if (element_index < element_count) {
-                dst[i*element_count*stride+it*WARP_SIZE] = 0;
+                copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE);
             } else {
                 break;
             }
@@ -183,6 +245,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
     constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
     constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
     constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
+    constexpr int ELEMENTS_PER_LDG_STG = 4;

     int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
     int local_seq = blockIdx.x + 1;
@@ -197,37 +260,41 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
     int local_idx = threadIdx.x;

     // the first element to process by the current thread
-    int thread_offset = first_batch * stride + local_idx;
+    int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
     grad += thread_offset;
     output += thread_offset;
     gradInput += thread_offset;

     // load data from global memory
     acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
-    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
+    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
+    input_t temp_grad[ELEMENTS_PER_LDG_STG];
+    input_t temp_output[ELEMENTS_PER_LDG_STG];
     #pragma unroll
     for (int i = 0;  i < WARP_BATCH;  ++i) {
         int batch_element_count = (i >= local_batches) ? 0 : local_seq;

         #pragma unroll
-        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
-            int element_index = local_idx + it * WARP_SIZE;
+        for (int it = 0;  it < WARP_ITERATIONS;  it += ELEMENTS_PER_LDG_STG) {
+            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
             if (element_index < batch_element_count) {
-                output_reg[i][it] = output[i*element_count*stride+it*WARP_SIZE];
-            } else {
-                output_reg[i][it] = acc_t(0);
+                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count * stride + it * WARP_SIZE);
+                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count * stride + it * WARP_SIZE);
+
+                #pragma unroll
+                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                    if (element_index + element < batch_element_count) {
+                        output_reg[i][it + element] = (acc_t)temp_output[element];
+                    }
+                }
+                #pragma unroll
+                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                    if (element_index + element < batch_element_count
)
{
grad_reg
[
i
][
it
+
element
]
=
(
acc_t
)
temp_grad
[
element
]
*
output_reg
[
i
][
it
+
element
];
}
}
}
}
}
}
#pragma unroll
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
++
it
)
{
int
element_index
=
local_idx
+
it
*
WARP_SIZE
;
if
(
element_index
<
batch_element_count
)
{
grad_reg
[
i
][
it
]
=
(
acc_t
)
grad
[
i
*
element_count
*
stride
+
it
*
WARP_SIZE
]
*
output_reg
[
i
][
it
];
}
else
{
grad_reg
[
i
][
it
]
=
acc_t
(
0
);
}
}
}
}
acc_t
sum
[
WARP_BATCH
];
acc_t
sum
[
WARP_BATCH
];
...
@@ -247,11 +314,16 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
...
@@ -247,11 +314,16 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
if
(
i
>=
local_batches
)
if
(
i
>=
local_batches
)
break
;
break
;
#pragma unroll
#pragma unroll
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
++
it
)
{
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
it
+=
ELEMENTS_PER_LDG_STG
)
{
int
element_index
=
local_idx
+
it
*
WARP_SIZE
;
int
element_index
=
ELEMENTS_PER_LDG_STG
*
local_idx
+
it
*
WARP_SIZE
;
if
(
element_index
<
element_count
)
{
if
(
element_index
<
element_count
)
{
// compute gradients
// compute gradients
gradInput
[
i
*
element_count
*
stride
+
it
*
WARP_SIZE
]
=
(
output_t
)(
scale
*
(
grad_reg
[
i
][
it
]
-
output_reg
[
i
][
it
]
*
sum
[
i
]));
output_t
out
[
ELEMENTS_PER_LDG_STG
];
#pragma unroll
for
(
int
element
=
0
;
element
<
ELEMENTS_PER_LDG_STG
;
++
element
)
{
out
[
element
]
=
(
output_t
)(
scale
*
(
grad_reg
[
i
][
it
+
element
]
-
output_reg
[
i
][
it
+
element
]
*
sum
[
i
]));
}
copy_vector
<
output_t
,
ELEMENTS_PER_LDG_STG
>
(
gradInput
+
i
*
element_count
*
stride
+
it
*
WARP_SIZE
,
out
);
}
}
}
}
}
}
...
...
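For readers following the vectorized-load rewrite above, the math of the forward kernel is unchanged: scale the attention scores, apply an upper-triangular (causal) mask, and take a softmax over the last dimension with fp32 accumulation. A minimal PyTorch reference of that behaviour, useful for spot-checking the fused path, might look like the sketch below; the function name is ours, not part of the repository.

import torch

def scaled_upper_triang_masked_softmax_ref(x, scale):
    """Reference for the fused kernel's math: scale the scores, apply a
    causal (upper-triangular) mask, then softmax over the last dimension.
    x: [attn_batches, seq_len, seq_len] attention scores in fp16 or bf16."""
    seq_len = x.size(-1)
    # Positions j > i are masked out for a causal mask.
    causal_mask = torch.triu(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device),
        diagonal=1)
    scores = x * scale
    scores = scores.masked_fill(causal_mask, float('-inf'))
    # The kernel accumulates in fp32 (acc_t = float) even for 16-bit inputs,
    # so compare against a float32 softmax cast back to the input dtype.
    return torch.softmax(scores.float(), dim=-1).to(x.dtype)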
megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu
View file @ a5acbf53
...
@@ -19,10 +19,10 @@
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
 #include <cuda_profiler_api.h>
-#include "THC/THC.h"
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
 #include "scaled_upper_triang_masked_softmax.h"
+#include "type_shim.h"

 namespace multihead_attn {
 namespace fused_softmax {
...
@@ -46,15 +46,20 @@ torch::Tensor fwd_cuda(
   void* input_ptr = static_cast<void*>(input.data_ptr());
   void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());

-  dispatch_scaled_upper_triang_masked_softmax_forward<half, half, float>(
-      reinterpret_cast<half*>(softmax_results_ptr),
-      reinterpret_cast<const half*>(input_ptr),
-      scale_factor,
-      seq_len,
-      seq_len,
-      attn_batches);
+  DISPATCH_HALF_AND_BFLOAT(
+      input.scalar_type(),
+      "dispatch_scaled_upper_triang_masked_softmax_forward",
+      dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t, float>(
+          reinterpret_cast<scalar_t*>(softmax_results_ptr),
+          reinterpret_cast<const scalar_t*>(input_ptr),
+          scale_factor,
+          seq_len,
+          seq_len,
+          attn_batches);
+  );

   return softmax_results;
 }

 torch::Tensor bwd_cuda(
     torch::Tensor const& output_grads_,
...
@@ -72,14 +77,18 @@ torch::Tensor bwd_cuda(
   void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

   //Softmax Grad
-  dispatch_scaled_upper_triang_masked_softmax_backward<half, half, float>(
-      reinterpret_cast<half*>(output_grads_ptr),
-      reinterpret_cast<half*>(output_grads_ptr),
-      reinterpret_cast<half const*>(softmax_results.data_ptr()),
-      scale_factor,
-      seq_len,
-      seq_len,
-      attn_batches);
+  DISPATCH_HALF_AND_BFLOAT(
+      output_grads_.scalar_type(),
+      "dispatch_scaled_upper_triang_masked_softmax_backward",
+      dispatch_scaled_upper_triang_masked_softmax_backward<scalar_t, scalar_t, float>(
+          reinterpret_cast<scalar_t*>(output_grads_ptr),
+          reinterpret_cast<scalar_t*>(output_grads_ptr),
+          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
+          scale_factor,
+          seq_len,
+          seq_len,
+          attn_batches);
+  );

   //backward pass is completely in-place
   return output_grads;
...
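The backward pass dispatched above computes, row by row, scale * probs * (grad - sum(grad * probs)), now with vectorized loads and stores. A small PyTorch sketch of that formula, handy for checking bwd_cuda numerically, could read as follows; the helper name is illustrative, not an API of the extension.

import torch

def upper_triang_softmax_backward_ref(output_grads, softmax_results, scale):
    # Mirrors the kernel's per-row formula:
    #   gradInput = scale * (g * p - p * sum_k(g_k * p_k))
    # where p are the saved softmax probabilities and g = dL/dp.
    inner = (output_grads * softmax_results).sum(dim=-1, keepdim=True)
    return scale * softmax_results * (output_grads - inner)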
megatron/fused_kernels/type_shim.h
View file @ a5acbf53
...
@@ -14,214 +14,78 @@
  * limitations under the License.
  */

+/*This code is copied fron NVIDIA apex:
+ *    https://github.com/NVIDIA/apex
+ *    with minor changes. */

 #include <ATen/ATen.h>
 #include "compat.h"

-// Forward/backward compatiblity hack around
-// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
-// pending more future-proof guidance from upstream.
-// struct TypeShim
-// {
-//   const at::Type& payload;
-//   TypeShim(const at::Type& type) : payload(type) {}
-//   // Enable trivial conversion to a const at::Type& for pre-3aeb78
-//   operator const at::Type&(){ return payload; };
-//   // Enable dispatch switch statements to take *this directly for post-3aeb78
-//   //operator at::ScalarType(){ return payload.; };
-// };

-#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
+#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
   switch(TYPE) \
   { \
-    case at::ScalarType::Float: \
+    case at::ScalarType::Half: \
     { \
-      using scalar_t_##LEVEL = float; \
+      using scalar_t = at::Half; \
       __VA_ARGS__; \
       break; \
     } \
-    case at::ScalarType::Half: \
+    case at::ScalarType::BFloat16: \
     { \
-      using scalar_t_##LEVEL = at::Half; \
+      using scalar_t = at::BFloat16; \
       __VA_ARGS__; \
       break; \
     } \
     default: \
       AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
   }

+#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
+  switch(TYPEIN) \
+  { \
+    case at::ScalarType::Float: \
+    { \
+      using scalar_t_in = float; \
+      switch(TYPEOUT) \
+      { \
+        case at::ScalarType::Float: \
+        { \
+          using scalar_t_out = float; \
+          __VA_ARGS__; \
+          break; \
+        } \
+        case at::ScalarType::Half: \
+        { \
+          using scalar_t_out = at::Half; \
+          __VA_ARGS__; \
+          break; \
+        } \
+        case at::ScalarType::BFloat16: \
+        { \
+          using scalar_t_out = at::BFloat16; \
+          __VA_ARGS__; \
+          break; \
+        } \
+        default: \
+          AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
+      } \
+      break; \
+    } \
+    case at::ScalarType::Half: \
+    { \
+      using scalar_t_in = at::Half; \
+      using scalar_t_out = at::Half; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    case at::ScalarType::BFloat16: \
+    { \
+      using scalar_t_in = at::BFloat16; \
+      using scalar_t_out = at::BFloat16; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    default: \
+      AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
+  }

-#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
-  switch(TYPE) \
-  { \
-    case at::ScalarType::Float: \
-    { \
-      using scalar_t_##LEVEL = float; \
-      __VA_ARGS__; \
-      break; \
-    } \
-    case at::ScalarType::Half: \
-    { \
-      using scalar_t_##LEVEL = at::Half; \
-      __VA_ARGS__; \
-      break; \
-    } \
-    case at::ScalarType::Byte: \
-    { \
-      using scalar_t_##LEVEL = uint8_t; \
-      __VA_ARGS__; \
-      break; \
-    } \
-    default: \
-      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
-  }
-
-#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
-  switch(TYPE) \
-  { \
-    case at::ScalarType::Double: \
-    { \
-      using scalar_t_##LEVEL = double; \
-      __VA_ARGS__; \
-      break; \
-    } \
-    case at::ScalarType::Float: \
-    { \
-      using scalar_t_##LEVEL = float; \
-      __VA_ARGS__; \
-      break; \
-    } \
-    case at::ScalarType::Half: \
-    { \
-      using scalar_t_##LEVEL = at::Half; \
-      __VA_ARGS__; \
-      break; \
-    } \
-    default: \
-      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
-  }
-
-#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
-  switch(TYPE) \
-  { \
-    case at::ScalarType::Double: \
-    { \
-      using scalar_t_##LEVEL = double; \
-      __VA_ARGS__; \
-      break; \
-    } \
-    case at::ScalarType::Float: \
-    { \
-      using scalar_t_##LEVEL = float; \
-      __VA_ARGS__; \
-      break; \
-    } \
-    default: \
-      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
-  }
-
-template<typename T>
-__device__ __forceinline__ T reduce_block_into_lanes(
-    T *x, T val, int lanes=1, bool share_result=false)
-    // lanes is intended to be <= 32.
-{
-  int tid = threadIdx.x + threadIdx.y*blockDim.x;
-  int blockSize = blockDim.x*blockDim.y;
-  // blockSize is intended to be a multiple of 32.
-
-  if(blockSize >= 64)
-  {
-    x[tid] = val;
-    __syncthreads();
-  }
-
-  #pragma unroll
-  for(int i = (blockSize >> 1); i >= 64; i >>= 1)
-  {
-    if(tid < i)
-      x[tid] = x[tid] + x[tid+i];
-    __syncthreads();
-  }
-
-  T final;
-
-  if(tid < 32)
-  {
-    if(blockSize >= 64)
-      final = x[tid] + x[tid+32];
-    else
-      final = val;
-    // __SYNCWARP();
-
-    #pragma unroll
-    for(int i = 16; i >= lanes; i >>= 1)
-      final = final + __shfl_down_sync(0xffffffff, final, i);
-  }
-
-  if(share_result)
-  {
-    if(tid < lanes)
-      x[tid] = final; // EpilogueOp
-    // Make sure the smem result is visible to all warps.
-    __syncthreads();
-  }
-
-  return final;
-}
-
-template<typename T>
-__device__ __forceinline__ T reduce_block_into_lanes_max_op(
-    T *x, T val, int lanes=1, bool share_result=false)
-    // lanes is intended to be <= 32.
-{
-  int tid = threadIdx.x + threadIdx.y*blockDim.x;
-  int blockSize = blockDim.x*blockDim.y;
-  // blockSize is intended to be a multiple of 32.
-
-  if(blockSize >= 64)
-  {
-    x[tid] = val;
-    __syncthreads();
-  }
-
-  #pragma unroll
-  for(int i = (blockSize >> 1); i >= 64; i >>= 1)
-  {
-    if(tid < i)
-      x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid+i]));
-    __syncthreads();
-  }
-
-  T final;
-
-  if(tid < 32)
-  {
-    if(blockSize >= 64)
-      final = fmaxf(fabsf(x[tid]), fabsf(x[tid+32]));
-    else
-      final = val;
-    // __SYNCWARP();
-
-    #pragma unroll
-    for(int i = 16; i >= lanes; i >>= 1)
-      final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
-  }
-
-  if(share_result)
-  {
-    if(tid < lanes)
-      x[tid] = final; // EpilogueOp
-    // Make sure the smem result is visible to all warps.
-    __syncthreads();
-  }
-
-  return final;
-}
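DISPATCH_HALF_AND_BFLOAT simply switches on the tensor's scalar type and instantiates the kernel body for at::Half or at::BFloat16, erroring on anything else. As a rough Python analogue (purely illustrative, not code from the repository), the same idea looks like this:

import torch

def dispatch_half_and_bfloat(tensor, name, half_fn, bfloat16_fn):
    # Pick an implementation based on the tensor's dtype and fail loudly
    # for unsupported types, mirroring the C++ macro's default branch.
    if tensor.dtype == torch.float16:
        return half_fn(tensor)
    if tensor.dtype == torch.bfloat16:
        return bfloat16_fn(tensor)
    raise TypeError(f"{name} not implemented for '{tensor.dtype}'")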
megatron/initialize.py
View file @ a5acbf53
...
@@ -17,16 +17,20 @@
 import random
 import os
+import time

 import numpy as np
 import torch

+from megatron import fused_kernels
 from megatron import get_adlr_autoresume
 from megatron import get_args
 from megatron import get_tensorboard_writer
 from megatron import mpu
 from megatron.global_vars import set_global_variables
-from megatron.mpu import set_tensor_model_parallel_rank, set_tensor_model_parallel_world_size
+from megatron.mpu import (set_tensor_model_parallel_rank,
+                          set_tensor_model_parallel_world_size)


 def initialize_megatron(extra_args_provider=None, args_defaults={},
                         ignore_unknown_args=False, allow_no_cuda=False):
...
@@ -37,8 +41,7 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
     what you are doing.
     Returns a function to finalize distributed env initialization
     (optionally, only when args.lazy_mpu_init == True)
     """
     if not allow_no_cuda:
         # Make sure cuda is available.
         assert torch.cuda.is_available(), 'Megatron requires CUDA.'
...
@@ -66,7 +69,8 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
         # delayed initialization of DDP-related stuff
         # We only set basic DDP globals
         set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
-        # and return function for external DDP manager to call when it has DDP initialized
+        # and return function for external DDP manager
+        # to call when it has DDP initialized
         set_tensor_model_parallel_rank(args.rank)
         return finish_mpu_init
     else:
...
@@ -79,16 +83,71 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
     # Autoresume.
     _init_autoresume()

-    # Compile dataset C++ code.
-    if torch.distributed.get_rank() == 0:
-        from megatron.data.dataset_utils import compile_helper
-        compile_helper()
-
-    # Simple barrier
-    torch.distributed.barrier()
+    # Compile dependencies.
+    _compile_dependencies()

     # No continuation function
     return None


+def _compile_dependencies():
+
+    args = get_args()
+
+    # =========================
+    # Compile dataset C++ code.
+    # =========================
+    # TODO: move this to ninja
+    if torch.distributed.get_rank() == 0:
+        start_time = time.time()
+        print('> compiling dataset index builder ...')
+        from megatron.data.dataset_utils import compile_helper
+        compile_helper()
+        print('>>> done with dataset index builder. Compilation time: {:.3f} '
+              'seconds'.format(time.time() - start_time), flush=True)
+
+    # ==================
+    # Load fused kernels
+    # ==================
+
+    # Custom kernel constraints check.
+    seq_len = args.seq_length
+    attn_batch_size = \
+        (args.num_attention_heads / args.tensor_model_parallel_size) * \
+        args.micro_batch_size
+    # Constraints on sequence length and attn_batch_size to enable warp based
+    # optimization and upper triangular optimization (for causal mask)
+    custom_kernel_constraint = seq_len > 16 and seq_len <= 2048 and \
+        seq_len % 4 == 0 and attn_batch_size % 4 == 0
+    # Print a warning.
+    if not ((args.fp16 or args.bf16) and
+            custom_kernel_constraint and
+            args.masked_softmax_fusion):
+        if args.rank == 0:
+            print('WARNING: constraints for invoking optimized'
+                  ' fused softmax kernel are not met. We default'
+                  ' back to unfused kernel invocations.', flush=True)
+
+    # Always build on rank zero first.
+    if torch.distributed.get_rank() == 0:
+        start_time = time.time()
+        print('> compiling and loading fused kernels ...', flush=True)
+        fused_kernels.load(args)
+        torch.distributed.barrier()
+    else:
+        torch.distributed.barrier()
+        fused_kernels.load(args)
+    # Simple barrier to make sure all ranks have passed the
+    # compilation phase successfully before moving on to the
+    # rest of the program. We think this might ensure that
+    # the lock is released.
+    torch.distributed.barrier()
+    if torch.distributed.get_rank() == 0:
+        print('>>> done with compiling and loading fused kernels. '
+              'Compilation time: {:.3f} seconds'.format(
+                  time.time() - start_time), flush=True)
+
+
 def _initialize_distributed():
     """Initialize torch.distributed and mpu."""
...
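The _compile_dependencies() helper added above follows a build-on-rank-zero-first pattern: rank 0 performs the slow JIT build, the other ranks wait at a barrier and then load from the warm build cache, and a final barrier keeps everyone in step. A stripped-down sketch of just that pattern is shown below; build_on_rank_zero_first is our name, not a Megatron helper, and it assumes torch.distributed is already initialized.

import time
import torch

def build_on_rank_zero_first(build_fn):
    """Rank 0 builds first; other ranks wait, then load from the warm cache."""
    if torch.distributed.get_rank() == 0:
        start = time.time()
        build_fn()
        torch.distributed.barrier()
        print('build took {:.3f} seconds'.format(time.time() - start),
              flush=True)
    else:
        torch.distributed.barrier()
        build_fn()
    # All ranks have finished building/loading before training continues.
    torch.distributed.barrier()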
megatron/model/__init__.py
View file @ a5acbf53
...
@@ -13,23 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-_LAYER_NORM = None
-
-def import_layernorm(fp32_residual_connection, bf16):
-    global _LAYER_NORM
-    if not _LAYER_NORM:
-        if bf16:
-            from torch.nn import LayerNorm
-        elif fp32_residual_connection:
-            from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
-        else:
-            from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
-        _LAYER_NORM = LayerNorm
-    return _LAYER_NORM
+from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm

 from .distributed import *
 from .bert_model import (BertModel,
...
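After this change every call site imports the same MixedFusedLayerNorm regardless of the precision flags. A hedged usage sketch follows; it assumes the fused_mix_prec_layer_norm_cuda extension has been built via megatron.fused_kernels.load() and that a CUDA device is available.

import torch
from megatron.model import LayerNorm  # MixedFusedLayerNorm after this change

# Normalizes over the last dimension; weight is initialized to ones and
# bias to zeros by reset_parameters().
ln = LayerNorm(1024, eps=1e-5).cuda()
x = torch.randn(8, 512, 1024, device='cuda')
y = ln(x)  # same shape as x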
megatron/model/bert_model.py
View file @ a5acbf53
...
@@ -22,7 +22,7 @@ from megatron import mpu
 from megatron.model.enums import AttnMaskType
 from megatron.model.language_model import parallel_lm_logits
 from megatron.model.language_model import get_language_model
-from megatron.model import import_layernorm
+from megatron.model import LayerNorm
 from megatron.model.utils import openai_gelu, erf_gelu
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal
...
@@ -78,7 +78,6 @@ class BertLMHead(MegatronModule):
         self.parallel_output = parallel_output

         self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
-        LayerNorm = import_layernorm(args.fp32_residual_connection, args.bf16)
         self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
         self.gelu = torch.nn.functional.gelu
         if args.openai_gelu:
...
megatron/model/fused_layer_norm.py
View file @ a5acbf53
...
@@ -15,29 +15,23 @@
 """This code is copied fron NVIDIA apex:
       https://github.com/NVIDIA/apex
-   with minor changes. """
+   with some changes. """

-import math
-import torch
 import numbers
+import torch
 from torch.nn.parameter import Parameter
 from torch.nn import init
-from torch.nn import functional as F
 import importlib

-global fused_layer_norm_cuda
-fused_layer_norm_cuda = None
 global fused_mix_prec_layer_norm_cuda
 fused_mix_prec_layer_norm_cuda = None


 class FusedLayerNormAffineFunction(torch.autograd.Function):

   @staticmethod
   def forward(ctx, input, weight, bias, normalized_shape, eps):
-    global fused_mix_prec_layer_norm_cuda
-    if fused_mix_prec_layer_norm_cuda is None:
-        fused_mix_prec_layer_norm_cuda = importlib.import_module(
-            "fused_mix_prec_layer_norm_cuda")
     ctx.normalized_shape = normalized_shape
     ctx.eps = eps
     input_ = input.contiguous()
...
@@ -46,134 +40,51 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
     output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(
         input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
     ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
     return output

   @staticmethod
   def backward(ctx, grad_output):
     input_, weight_, bias_, mean, invvar = ctx.saved_tensors
     grad_input = grad_weight = grad_bias = None
-    grad_input, grad_weight, grad_bias = fused_mix_prec_layer_norm_cuda.backward_affine(
-        grad_output.contiguous(), mean, invvar,
-        input_, ctx.normalized_shape,
-        weight_, bias_, ctx.eps)
+    grad_input, grad_weight, grad_bias \
+      = fused_mix_prec_layer_norm_cuda.backward_affine(
+        grad_output.contiguous(), mean, invvar,
+        input_, ctx.normalized_shape,
+        weight_, bias_, ctx.eps)
     return grad_input, grad_weight, grad_bias, None, None


-class FusedLayerNormFunction(torch.autograd.Function):
-
-  @staticmethod
-  def forward(ctx, input, normalized_shape, eps):
-    global fused_layer_norm_cuda
-    if fused_layer_norm_cuda is None:
-        fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
-    ctx.normalized_shape = normalized_shape
-    ctx.eps = eps
-    input_ = input.contiguous()
-    output, mean, invvar = fused_layer_norm_cuda.forward(
-        input_, ctx.normalized_shape, ctx.eps)
-    ctx.save_for_backward(input_, mean, invvar)
-    return output
-
-  @staticmethod
-  def backward(ctx, grad_output):
-    input_, mean, invvar = ctx.saved_tensors
-    grad_input = None
-    grad_input = fused_layer_norm_cuda.backward(
-        grad_output.contiguous(), mean, invvar,
-        input_, ctx.normalized_shape, ctx.eps)
-    return grad_input, None, None
-
-
-def fused_layer_norm_affine(input, normalized_shape, weight, bias, eps=1e-6):
-    return FusedLayerNormAffineFunction.apply(
-        input, weight, bias, normalized_shape, eps)
-
-
-def fused_layer_norm(input, normalized_shape, eps=1e-6):
-    return FusedLayerNormFunction.apply(input, normalized_shape, eps)
-
-
 class MixedFusedLayerNorm(torch.nn.Module):

-  r"""Applies Layer Normalization over a mini-batch of inputs as described in
-  the paper `Layer Normalization`_ .
-
-  Currently only runs on cuda() tensors.
-
-  .. math::
-      y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
-
-  The mean and standard-deviation are calculated separately over the last
-  certain number dimensions which have to be of the shape specified by
-  :attr:`normalized_shape`.
-  :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
-  :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.
-
-  .. note::
-      Unlike Batch Normalization and Instance Normalization, which applies
-      scalar scale and bias for each entire channel/plane with the
-      :attr:`affine` option, Layer Normalization applies per-element scale and
-      bias with :attr:`elementwise_affine`.
-
-  This layer uses statistics computed from input data in both training and
-  evaluation modes.
-
-  Args:
-      normalized_shape (int or list or torch.Size): input shape from an expected input
-          of size
-
-          .. math::
-              [* \times \text{normalized}\_\text{shape}[0] \times \text{normalized}\_\text{shape}[1]
-                  \times \ldots \times \text{normalized}\_\text{shape}[-1]]
-
-          If a single integer is used, it is treated as a singleton list, and this module will
-          normalize over the last dimension which is expected to be of that specific size.
-      eps: a value added to the denominator for numerical stability. Default: 1e-5
-      elementwise_affine: a boolean value that when set to ``True``, this module
-          has learnable per-element affine parameters initialized to ones (for weights)
-          and zeros (for biases). Default: ``True``.
-
-  Shape:
-      - Input: :math:`(N, *)`
-      - Output: :math:`(N, *)` (same shape as input)
-
-  Examples::
-      >>> input = torch.randn(20, 5, 10, 10)
-      >>> # With Learnable Parameters
-      >>> m = apex.normalization.FusedLayerNorm(input.size()[1:])
-      >>> # Without Learnable Parameters
-      >>> m = apex.normalization.FusedLayerNorm(input.size()[1:], elementwise_affine=False)
-      >>> # Normalize over last two dimensions
-      >>> m = apex.normalization.FusedLayerNorm([10, 10])
-      >>> # Normalize over last dimension of size 10
-      >>> m = apex.normalization.FusedLayerNorm(10)
-      >>> # Activating the module
-      >>> output = m(input)
-
-  .. _`Layer Normalization`: https://arxiv.org/abs/1607.06450
-  """
-
-  def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
+  def __init__(self, normalized_shape, eps=1e-5):
     super(MixedFusedLayerNorm, self).__init__()

-    global fused_layer_norm_cuda
-    fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
     global fused_mix_prec_layer_norm_cuda
-    fused_mix_prec_layer_norm_cuda = importlib.import_module("fused_mix_prec_layer_norm_cuda")
+    fused_mix_prec_layer_norm_cuda = importlib.import_module(
+        "fused_mix_prec_layer_norm_cuda")

     if isinstance(normalized_shape, numbers.Integral):
         normalized_shape = (normalized_shape,)
     self.normalized_shape = torch.Size(normalized_shape)
     self.eps = eps
-    self.elementwise_affine = elementwise_affine
-    if self.elementwise_affine:
-        self.weight = Parameter(torch.Tensor(*normalized_shape))
-        self.bias = Parameter(torch.Tensor(*normalized_shape))
-    else:
-        self.register_parameter('weight', None)
-        self.register_parameter('bias', None)
+    self.weight = Parameter(torch.Tensor(*normalized_shape))
+    self.bias = Parameter(torch.Tensor(*normalized_shape))
     self.reset_parameters()

   def reset_parameters(self):
-    if self.elementwise_affine:
-        init.ones_(self.weight)
-        init.zeros_(self.bias)
+    init.ones_(self.weight)
+    init.zeros_(self.bias)

   def forward(self, input):
-    if not input.is_cuda:
-        return F.layer_norm(
-            input, self.normalized_shape, self.weight, self.bias, self.eps)
-    if self.elementwise_affine:
-      return FusedLayerNormAffineFunction.apply(
-          input, self.weight, self.bias, self.normalized_shape, self.eps)
-    else:
-      return FusedLayerNormFunction.apply(input, self.normalized_shape, self.eps)
-
-  def extra_repr(self):
-    return '{normalized_shape}, eps={eps}, ' \
-        'elementwise_affine={elementwise_affine}'.format(**self.__dict__)
+    return FusedLayerNormAffineFunction.apply(
+        input, self.weight, self.bias, self.normalized_shape, self.eps)
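Since MixedFusedLayerNorm now always owns an affine weight and bias, a quick way to sanity-check it is to compare against the stock functional layer norm. The snippet below is an illustrative check, again assuming the fused_mix_prec_layer_norm_cuda extension has been built and a GPU is present.

import torch
import torch.nn.functional as F
from megatron.model.fused_layer_norm import MixedFusedLayerNorm

ln = MixedFusedLayerNorm(1024).cuda()
x = torch.randn(4, 128, 1024, device='cuda')
fused = ln(x)
ref = F.layer_norm(x, ln.normalized_shape, ln.weight, ln.bias, ln.eps)
# The two results should agree to roughly float32 rounding error.
print(torch.max(torch.abs(fused - ref)).item())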
megatron/model/fused_softmax.py
View file @ a5acbf53
...
@@ -96,6 +96,7 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
     def __init__(
         self,
         input_in_fp16,
+        input_in_bf16,
         attn_mask_type,
         scaled_masked_softmax_fusion,
         mask_func,
...
@@ -104,6 +105,10 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
     ):
         super(FusedScaleMaskSoftmax, self).__init__()
         self.input_in_fp16 = input_in_fp16
+        self.input_in_bf16 = input_in_bf16
+        assert not (self.input_in_fp16 and self.input_in_bf16), \
+            'both fp16 and bf16 flags cannot be active at the same time.'
+        self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
         self.attn_mask_type = attn_mask_type
         self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
         self.mask_func = mask_func
...
@@ -128,8 +133,8 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
             query_seq_len % 4 == 0 and attn_batch_size % 4 == 0

         # invoke custom kernel
-        if self.input_in_fp16 and mask is not None and \
+        if self.input_in_float16 and mask is not None and \
             custom_kernel_constraint and self.scaled_masked_softmax_fusion:
             scale = self.scale if self.scale is not None else 1.0
             if self.attn_mask_type == AttnMaskType.causal:
...
@@ -142,7 +147,7 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
                 assert self.attn_mask_type == AttnMaskType.padding
                 probs = ScaledMaskedSoftmax.apply(input, mask, scale)
         else:
-            if self.input_in_fp16 and self.softmax_in_fp32:
+            if self.input_in_float16 and self.softmax_in_fp32:
                 input = input.float()

             if self.scale is not None:
...
@@ -150,7 +155,10 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
             mask_output = self.mask_func(input, mask) if mask is not None else input
             probs = torch.nn.Softmax(dim=-1)(mask_output)

-            if self.input_in_fp16 and self.softmax_in_fp32:
-                probs = probs.half()
+            if self.input_in_float16 and self.softmax_in_fp32:
+                if self.input_in_fp16:
+                    probs = probs.half()
+                else:
+                    probs = probs.bfloat16()

         return probs
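The unfused fallback branch above now remembers whether the input was fp16 or bf16 and casts the fp32 softmax back to the matching 16-bit type. A standalone sketch of that branch (our helper name, mirroring but not reproducing FusedScaleMaskSoftmax.forward) is:

import torch

def unfused_scale_mask_softmax(scores, mask, scale, input_in_fp16,
                               input_in_bf16, softmax_in_fp32, mask_func):
    """Optionally compute the softmax in fp32, then cast back to the
    original 16-bit type (half or bfloat16) instead of assuming fp16."""
    input_in_float16 = input_in_fp16 or input_in_bf16
    if input_in_float16 and softmax_in_fp32:
        scores = scores.float()
    if scale is not None:
        scores = scores * scale
    masked = mask_func(scores, mask) if mask is not None else scores
    probs = torch.nn.Softmax(dim=-1)(masked)
    if input_in_float16 and softmax_in_fp32:
        probs = probs.half() if input_in_fp16 else probs.bfloat16()
    return probs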
megatron/model/transformer.py
View file @ a5acbf53
...
@@ -22,7 +22,7 @@ from megatron import get_args
 from megatron import mpu
 from .module import MegatronModule
 from megatron.model.enums import AttnMaskType, LayerType, AttnType
-from megatron.model import import_layernorm
+from megatron.model import LayerNorm
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
 from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
...
@@ -116,6 +116,7 @@ class ParallelAttention(MegatronModule):
         super(ParallelAttention, self).__init__()
         args = get_args()
         self.fp16 = args.fp16
+        self.bf16 = args.bf16

         self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
         self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
...
@@ -164,7 +165,7 @@ class ParallelAttention(MegatronModule):
             self.norm_factor *= coeff

         self.scale_mask_softmax = FusedScaleMaskSoftmax(
-            self.fp16,
+            self.fp16, self.bf16,
             self.attn_mask_type,
             args.masked_softmax_fusion,
             attention_mask_func,
...
@@ -401,7 +402,6 @@ class ParallelTransformerLayer(MegatronModule):
         self.fp32_residual_connection = args.fp32_residual_connection

         # Layernorm on the input data.
-        LayerNorm = import_layernorm(self.fp32_residual_connection, self.bf16)
         self.input_layernorm = LayerNorm(
             args.hidden_size,
             eps=args.layernorm_epsilon)
...
@@ -443,8 +443,6 @@ class ParallelTransformerLayer(MegatronModule):
         # Layer norm at the beginning of the transformer layer.
         layernorm_output = self.input_layernorm(hidden_states)
-        if self.bf16 and self.fp32_residual_connection:
-            layernorm_output = layernorm_output.bfloat16()

         # Self attention.
         attention_output, attention_bias = \
             self.self_attention(layernorm_output,
...
@@ -483,8 +481,6 @@ class ParallelTransformerLayer(MegatronModule):
         # Layer norm post the self attention.
         layernorm_output = self.post_attention_layernorm(layernorm_input)
-        if self.bf16 and self.fp32_residual_connection:
-            layernorm_output = layernorm_output.bfloat16()

         if self.layer_type == LayerType.decoder:
             attention_output, attention_bias = \
...
@@ -507,8 +503,6 @@ class ParallelTransformerLayer(MegatronModule):
         # Layer norm post the decoder attention
         layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
-        if self.bf16 and self.fp32_residual_connection:
-            layernorm_output = layernorm_output.bfloat16()

         # MLP.
         mlp_output, mlp_bias = self.mlp(layernorm_output)
...
@@ -588,8 +582,6 @@ class ParallelTransformer(MegatronModule):
         if mpu.is_pipeline_last_stage():
             # Final layer norm before output.
-            LayerNorm = import_layernorm(self.fp32_residual_connection, self.bf16)
             self.final_layernorm = LayerNorm(
                 args.hidden_size,
                 eps=args.layernorm_epsilon)
...
@@ -676,8 +668,6 @@ class ParallelTransformer(MegatronModule):
             # Reverting data format change [s b h] --> [b s h].
             hidden_states = hidden_states.transpose(0, 1).contiguous()
             output = self.final_layernorm(hidden_states)
-            if self.bf16 and self.fp32_residual_connection:
-                output = output.bfloat16()
         else:
             output = hidden_states
         if get_key_value:
...
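ParallelAttention now just forwards the fp16/bf16 flags; whether the fused softmax actually fires still depends on the constraint checked both in megatron/initialize.py and in FusedScaleMaskSoftmax. Restated as a small helper (not part of the codebase, shown only to make the condition explicit):

def fused_softmax_kernel_supported(seq_len, num_attention_heads,
                                   tensor_model_parallel_size,
                                   micro_batch_size, fp16, bf16,
                                   masked_softmax_fusion):
    # Same constraint as _compile_dependencies() and the forward pass:
    # 16 < seq_len <= 2048, seq_len and per-partition attention batch
    # divisible by 4, 16-bit inputs, and the fusion flag enabled.
    attn_batch_size = (num_attention_heads / tensor_model_parallel_size) \
        * micro_batch_size
    custom_kernel_constraint = 16 < seq_len <= 2048 and \
        seq_len % 4 == 0 and attn_batch_size % 4 == 0
    return (fp16 or bf16) and custom_kernel_constraint and masked_softmax_fusion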
megatron/optimizer/__init__.py
View file @ a5acbf53
...
@@ -17,7 +17,7 @@ from apex.optimizers import FusedAdam as Adam
 from apex.optimizers import FusedSGD as SGD

 from megatron import get_args
-from megatron.model import import_layernorm
+from megatron.model import LayerNorm

 from .grad_scaler import ConstantGradScaler, DynamicGradScaler
 from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer
...
@@ -27,8 +27,6 @@ def _get_params_for_weight_decay_optimization(modules):
     """Divide params into with-weight-decay and without-weight-decay groups.
     Layernorms and baises will have no weight decay but the rest will.
     """
-    args = get_args()
-    LayerNorm = import_layernorm(args.fp32_residual_connection, args.bf16)
     weight_decay_params = {'params': []}
     no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
...
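With a single LayerNorm class, _get_params_for_weight_decay_optimization can test isinstance(module, LayerNorm) directly. The sketch below mirrors, rather than reproduces, that grouping; parameter collection is simplified and the function name differs from the repository's.

from megatron.model import LayerNorm

def group_params_for_weight_decay(modules):
    """LayerNorm parameters and biases get no weight decay; the rest do."""
    weight_decay_params = {'params': []}
    no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
    for module in modules:
        for module_ in module.modules():
            if isinstance(module_, LayerNorm):
                no_weight_decay_params['params'].extend(
                    list(module_.parameters()))
            else:
                weight_decay_params['params'].extend(
                    p for n, p in module_.named_parameters(recurse=False)
                    if n != 'bias')
                no_weight_decay_params['params'].extend(
                    p for n, p in module_.named_parameters(recurse=False)
                    if n == 'bias')
    return weight_decay_params, no_weight_decay_params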
megatron/training.py
View file @ a5acbf53
...
@@ -224,16 +224,6 @@ def get_model(model_provider_func):
     # Fp16 conversion.
     if args.fp16 or args.bf16:
         model = [Float16Module(model_module, args) for model_module in model]

-    # For now, the layer norm does not support input float32 and outut bf16.
-    # For this, we move layernorm parameters to fp32 and cast output of the
-    # layernorm operation back to bf16.
-    if args.bf16 and args.fp32_residual_connection:
-        from megatron.model import import_layernorm
-        LayerNorm = import_layernorm(args.fp32_residual_connection, args.bf16)
-        for model_ in model:
-            for module_ in model_.modules():
-                if isinstance(module_, LayerNorm):
-                    module_.float()
-
     if args.DDP_impl == 'torch':
         i = torch.cuda.current_device()
...