OpenDAS / Megatron-LM

Commit 0d5188c1
Authored Mar 17, 2021 by Mohammad Shoeybi; committed by Jared Casper, Mar 17, 2021
Parent: 876096d5

refactored the fused kernels build
Showing 6 changed files with 143 additions and 168 deletions (+143, -168).
megatron/arguments.py  (+0, -26)
megatron/fused_kernels/__init__.py  (+71, -87)
megatron/fused_kernels/layer_norm_cuda.cpp  (+1, -41)
megatron/fused_kernels/scaled_masked_softmax_cuda.cu  (+0, -1)
megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu  (+0, -1)
megatron/initialize.py  (+71, -12)
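In broad strokes, the refactor moves fused-kernel compilation out of argument parsing and behind a single entry point that initialization calls. A minimal before/after sketch of the call sites, using only the function names that appear in the diffs below (the surrounding setup is elided and hypothetical):

# Sketch of the loading API change, assuming a parsed Megatron `args` namespace.
from megatron import fused_kernels

def load_kernels_old_style(args):
    # Pre-commit behaviour: parse_args() invoked one loader per kernel.
    if args.masked_softmax_fusion:
        fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
        fused_kernels.load_scaled_masked_softmax_fusion_kernel()
    if args.fp32_residual_connection:
        fused_kernels.load_fused_mix_prec_layer_norm_kernel()

def load_kernels_new_style(args):
    # Post-commit behaviour: one entry point; the flag checks live inside load().
    fused_kernels.load(args)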
megatron/arguments.py

@@ -19,7 +19,6 @@ import argparse
 import os
 
 import torch
-from megatron import fused_kernels
 
 def parse_args(extra_args_provider=None, defaults={},
                ignore_unknown_args=False):
@@ -227,31 +226,6 @@ def parse_args(extra_args_provider=None, defaults={},
         'for distribute-checkpointed-activations to work you '\
         'need to enable checkpoint-activations'
 
-    # custom kernel constraints check
-    seq_len = args.seq_length
-    attn_batch_size = \
-        (args.num_attention_heads / args.tensor_model_parallel_size) * \
-        args.micro_batch_size
-    # constraints on sequence length and attn_batch_size to enable warp based
-    # optimization and upper triangular optimization (for causal mask)
-    custom_kernel_constraint = seq_len > 16 and seq_len <= 2048 and \
-        seq_len % 4 == 0 and attn_batch_size % 4 == 0
-    if not (args.fp16 and custom_kernel_constraint and args.masked_softmax_fusion):
-        print('WARNING: constraints for invoking optimized'
-              ' fused softmax kernel are not met. We default back to unfused'
-              ' kernel invocations.')
-
-    # Load scaled_masked_softmax_fusion_kernels
-    if args.masked_softmax_fusion:
-        fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
-        fused_kernels.load_scaled_masked_softmax_fusion_kernel()
-
-    # Load mixed precision fused layer norm.
-    if args.fp32_residual_connection:
-        fused_kernels.load_fused_mix_prec_layer_norm_kernel()
-
     _print_args(args)
     return args
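To make the removed constraint check concrete, here is a small worked example of the same arithmetic. The attribute names come from the diff above; the numeric values are illustrative only:

# Hypothetical settings, chosen only to illustrate the constraint arithmetic.
seq_length = 1024
num_attention_heads = 16
tensor_model_parallel_size = 2
micro_batch_size = 4

# attn_batch_size = (heads per tensor-parallel rank) * micro batch size
attn_batch_size = (num_attention_heads / tensor_model_parallel_size) * micro_batch_size
# -> (16 / 2) * 4 = 32.0

custom_kernel_constraint = (16 < seq_length <= 2048
                            and seq_length % 4 == 0
                            and attn_batch_size % 4 == 0)
print(custom_kernel_constraint)  # True: the fused softmax kernel path would be eligible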
megatron/fused_kernels/__init__.py

The previous version of this module exposed get_cuda_bare_metal_version() and create_build_dir() at module level plus three standalone loaders (load_scaled_upper_triang_masked_softmax_fusion_kernel, load_scaled_masked_softmax_fusion_kernel, load_fused_mix_prec_layer_norm_kernel), each repeating the CUDA 11 check, the build-directory setup, and its own cpp_extension.load() call. The refactor replaces them with a single load(args) entry point that does the common setup once and builds each kernel through a shared helper; the version and build-directory utilities become the private _get_cuda_bare_metal_version() and _create_build_dir(). The file after the change (the unchanged Apache license header at the top is omitted):

import os
import pathlib
import subprocess

from torch.utils import cpp_extension

# Setting this param to a list has a problem of generating different
# compilation commands (with diferent order of architectures) and
# leading to recompilation of fused kernels. Set it to empty string
# to avoid recompilation and assign arch flags explicity in
# extra_cuda_cflags below
os.environ["TORCH_CUDA_ARCH_LIST"] = ""


def load(args):

    # Check if cuda 11 is installed for compute capability 8.0
    cc_flag = []
    _, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
    if int(bare_metal_major) >= 11:
        cc_flag.append('-gencode')
        cc_flag.append('arch=compute_80,code=sm_80')

    # Build path
    srcpath = pathlib.Path(__file__).parent.absolute()
    buildpath = srcpath / 'build'
    _create_build_dir(buildpath)

    # Helper function to build the kernels.
    def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
        return cpp_extension.load(
            name=name,
            sources=sources,
            build_directory=buildpath,
            extra_cflags=['-O3',],
            extra_cuda_cflags=['-O3',
                               '-gencode', 'arch=compute_70,code=sm_70',
                               '--use_fast_math'] + extra_cuda_flags + cc_flag,
            verbose=(args.rank == 0)
        )

    # ==============
    # Fused softmax.
    # ==============

    if args.masked_softmax_fusion:
        extra_cuda_flags = ['-U__CUDA_NO_HALF_OPERATORS__',
                            '-U__CUDA_NO_HALF_CONVERSIONS__',
                            '--expt-relaxed-constexpr',
                            '--expt-extended-lambda']

        # Upper triangular softmax.
        sources = [srcpath / 'scaled_upper_triang_masked_softmax.cpp',
                   srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu']
        scaled_upper_triang_masked_softmax_cuda = _cpp_extention_load_helper(
            "scaled_upper_triang_masked_softmax_cuda", sources, extra_cuda_flags)

        # Masked softmax.
        sources = [srcpath / 'scaled_masked_softmax.cpp',
                   srcpath / 'scaled_masked_softmax_cuda.cu']
        scaled_masked_softmax_cuda = _cpp_extention_load_helper(
            "scaled_masked_softmax_cuda", sources, extra_cuda_flags)

    # =================================
    # Mixed precision fused layer norm.
    # =================================

    if args.fp32_residual_connection:
        extra_cuda_flags = ['-maxrregcount=50']
        sources = [srcpath / 'layer_norm_cuda.cpp',
                   srcpath / 'layer_norm_cuda_kernel.cu']
        fused_mix_prec_layer_norm_cuda = _cpp_extention_load_helper(
            "fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags)


def _get_cuda_bare_metal_version(cuda_dir):
    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
                                         universal_newlines=True)
    output = raw_output.split()
    release_idx = output.index("release") + 1
    release = output[release_idx].split(".")
    bare_metal_major = release[0]
    bare_metal_minor = release[1][0]

    return raw_output, bare_metal_major, bare_metal_minor


def _create_build_dir(buildpath):
    try:
        os.mkdir(buildpath)
    except OSError:
        if not os.path.isdir(buildpath):
            print(f"Creation of the build directory {buildpath} failed")
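As a quick illustration of what _get_cuda_bare_metal_version() extracts, here is the same parsing applied to a canned `nvcc -V` banner. The banner text is a typical example, not taken from this commit:

# Parse an example nvcc banner the same way _get_cuda_bare_metal_version does.
raw_output = (
    "nvcc: NVIDIA (R) Cuda compiler driver\n"
    "Cuda compilation tools, release 11.1, V11.1.105\n"
)
output = raw_output.split()
release_idx = output.index("release") + 1
release = output[release_idx].split(".")   # ['11', '1,']
bare_metal_major = release[0]              # '11'
bare_metal_minor = release[1][0]           # '1'
print(bare_metal_major, bare_metal_minor)  # 11 1 -> major >= 11 enables the sm_80 -gencode flags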
megatron/fused_kernels/layer_norm_cuda.cpp

Every declaration in this file drops the `#ifdef VERSION_GE_1_1` / `at::IntList` fallback and keeps only `at::IntArrayRef`, and the CHECK_CUDA macro switches from `x.type().is_cuda()` to `x.is_cuda()`:

@@ -26,11 +26,7 @@
 namespace {
 void compute_n1_n2(
     at::Tensor input,
-#ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-#else
-    at::IntList normalized_shape,
-#endif
     int& n1,
     int& n2)
 {
@@ -47,11 +43,7 @@ void compute_n1_n2(
 }
 
 void check_args(
-#ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-#else
-    at::IntList normalized_shape,
-#endif
     at::Tensor gamma,
     at::Tensor beta
     )
@@ -62,11 +54,7 @@ void check_args(
 void check_args(
     at::Tensor input,
-#ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-#else
-    at::IntList normalized_shape,
-#endif
     int& n1,
     int& n2
     )
@@ -102,11 +90,7 @@ void check_args(
 void check_args(
     at::Tensor input,
-#ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-#else
-    at::IntList normalized_shape,
-#endif
     at::Tensor gamma,
     at::Tensor beta,
     int& n1,
@@ -125,26 +109,18 @@ void cuda_layer_norm(
     at::Tensor* input,
     int n1,
     int n2,
-#ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-#else
-    at::IntList normalized_shape,
-#endif
     at::Tensor* gamma,
     at::Tensor* beta,
     double epsilon);
 
-#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
 #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
 #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
 
 std::vector<at::Tensor> layer_norm(
     at::Tensor input,
-#ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-#else
-    at::IntList normalized_shape,
-#endif
     double epsilon) {
   CHECK_INPUT(input);
   int n1, n2;
@@ -158,11 +134,7 @@ std::vector<at::Tensor> layer_norm(
 }
 
 std::vector<at::Tensor> layer_norm_affine(
     at::Tensor input,
-#ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-#else
-    at::IntList normalized_shape,
-#endif
     at::Tensor gamma,
     at::Tensor beta,
     double epsilon) {
@@ -186,11 +158,7 @@ void cuda_layer_norm_gradient(
     at::Tensor* input,
     int n1,
     int n2,
-#ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-#else
-    at::IntList normalized_shape,
-#endif
     at::Tensor* gamma,
     at::Tensor* beta,
     double epsilon,
@@ -204,11 +172,7 @@ at::Tensor layer_norm_gradient(
     at::Tensor mean,
     at::Tensor invvar,
     at::Tensor input,
-#ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-#else
-    at::IntList normalized_shape,
-#endif
     double epsilon) {
   CHECK_INPUT(dout);
   CHECK_INPUT(mean);
@@ -227,11 +191,7 @@ std::vector<at::Tensor> layer_norm_gradient_affine(
     at::Tensor mean,
     at::Tensor invvar,
     at::Tensor input,
-#ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-#else
-    at::IntList normalized_shape,
-#endif
     at::Tensor gamma,
     at::Tensor beta,
     double epsilon) {
megatron/fused_kernels/scaled_masked_softmax_cuda.cu

@@ -19,7 +19,6 @@
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
 #include <cuda_profiler_api.h>
-#include "THC/THC.h"
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
 #include "scaled_masked_softmax.h"
megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu

@@ -19,7 +19,6 @@
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
 #include <cuda_profiler_api.h>
-#include "THC/THC.h"
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
 #include "scaled_upper_triang_masked_softmax.h"
megatron/initialize.py

@@ -17,16 +17,20 @@
 import random
 import os
+import time
 
 import numpy as np
 import torch
 
+from megatron import fused_kernels
 from megatron import get_adlr_autoresume
 from megatron import get_args
 from megatron import get_tensorboard_writer
 from megatron import mpu
 from megatron.global_vars import set_global_variables
-from megatron.mpu import set_tensor_model_parallel_rank, set_tensor_model_parallel_world_size
+from megatron.mpu import (set_tensor_model_parallel_rank,
+                          set_tensor_model_parallel_world_size)
 
 def initialize_megatron(extra_args_provider=None, args_defaults={},
                         ignore_unknown_args=False, allow_no_cuda=False):
@@ -37,8 +41,7 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
     what you are doing.
     Returns a function to finalize distributed env initialization
     (optionally, only when args.lazy_mpu_init == True)
-
     """
     if not allow_no_cuda:
         # Make sure cuda is available.
         assert torch.cuda.is_available(), 'Megatron requires CUDA.'
@@ -66,7 +69,8 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
         # delayed initialization of DDP-related stuff
         # We only set basic DDP globals
         set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
-        # and return function for external DDP manager to call when it has DDP initialized
+        # and return function for external DDP manager
+        # to call when it has DDP initialized
         set_tensor_model_parallel_rank(args.rank)
         return finish_mpu_init
     else:
@@ -79,16 +83,71 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
         # Autoresume.
         _init_autoresume()
 
-        # Compile dataset C++ code.
-        if torch.distributed.get_rank() == 0:
-            from megatron.data.dataset_utils import compile_helper
-            compile_helper()
-
-        # Simple barrier
-        torch.distributed.barrier()
+        # Compile dependencies.
+        _compile_dependencies()
 
         # No continuation function
         return None
 
+
+def _compile_dependencies():
+
+    args = get_args()
+
+    # =========================
+    # Compile dataset C++ code.
+    # =========================
+    # TODO: move this to ninja
+    if torch.distributed.get_rank() == 0:
+        start_time = time.time()
+        print('> compiling dataset index builder ...')
+        from megatron.data.dataset_utils import compile_helper
+        compile_helper()
+        print('>>> done with dataset index builder. Compilation time: {:.3f} '
+              'seconds'.format(time.time() - start_time), flush=True)
+
+    # ==================
+    # Load fused kernels
+    # ==================
+
+    # Custom kernel constraints check.
+    seq_len = args.seq_length
+    attn_batch_size = \
+        (args.num_attention_heads / args.tensor_model_parallel_size) * \
+        args.micro_batch_size
+    # Constraints on sequence length and attn_batch_size to enable warp based
+    # optimization and upper triangular optimization (for causal mask)
+    custom_kernel_constraint = seq_len > 16 and seq_len <= 2048 and \
+        seq_len % 4 == 0 and attn_batch_size % 4 == 0
+    # Print a warning.
+    if not ((args.fp16 or args.bf16) and
+            custom_kernel_constraint and
+            args.masked_softmax_fusion):
+        if args.rank == 0:
+            print('WARNING: constraints for invoking optimized'
+                  ' fused softmax kernel are not met. We default'
+                  ' back to unfused kernel invocations.', flush=True)
+
+    # Always build on rank zero first.
+    if torch.distributed.get_rank() == 0:
+        start_time = time.time()
+        print('> compiling and loading fused kernels ...', flush=True)
+        fused_kernels.load(args)
+        torch.distributed.barrier()
+    else:
+        torch.distributed.barrier()
+        fused_kernels.load(args)
+
+    # Simple barrier to make sure all ranks have passed the
+    # compilation phase successfully before moving on to the
+    # rest of the program. We think this might ensure that
+    # the lock is released.
+    torch.distributed.barrier()
+    if torch.distributed.get_rank() == 0:
+        print('>>> done with compiling and loading fused kernels. '
+              'Compilation time: {:.3f} seconds'.format(
+                  time.time() - start_time), flush=True)
+
 
 def _initialize_distributed():
     """Initialize torch.distributed and mpu."""
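For context, the net effect on a training entry point is that kernel compilation now happens during initialization rather than during argument parsing. A minimal, hypothetical usage sketch (the args_defaults value is illustrative, not from this commit):

# Hypothetical launcher sketch: initialize_megatron() now triggers
# _compile_dependencies(), which builds the dataset index helper and loads the
# fused kernels via fused_kernels.load(args) before training starts.
from megatron.initialize import initialize_megatron

if __name__ == '__main__':
    initialize_megatron(extra_args_provider=None,
                        args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
    # ... build the model, data iterators, and run the training loop here ...

The rank-0-first pattern in _compile_dependencies() means only one process actually runs the compiler; the other ranks wait at the barrier and then load the already-built extensions from the shared build directory.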