Commit d9b1c681 authored by Jared Casper

Merge branch 'ninja_compilation_fix' into 'main'

refactored the fused kernels build

See merge request ADLR/megatron-lm!250
parents 876096d5 0d5188c1
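
In short, this refactor collapses the three per-kernel loaders that parse_args() used to call into a single fused_kernels.load(args) entry point, now invoked during initialization. The condensed sketch below is illustrative only (it assumes Megatron's global args have already been set up) and is not part of the diff:

    from megatron import fused_kernels
    from megatron import get_args

    args = get_args()

    # Before this MR, argument parsing compiled each kernel with its own loader:
    #   fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
    #   fused_kernels.load_scaled_masked_softmax_fusion_kernel()
    #   fused_kernels.load_fused_mix_prec_layer_norm_kernel()

    # After this MR, one entry point inspects args (masked_softmax_fusion,
    # fp32_residual_connection, rank) and builds and loads only what is needed:
    fused_kernels.load(args)
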
@@ -19,7 +19,6 @@ import argparse
 import os
 import torch
-from megatron import fused_kernels

 def parse_args(extra_args_provider=None, defaults={},
                ignore_unknown_args=False):
@@ -227,31 +226,6 @@ def parse_args(extra_args_provider=None, defaults={},
        'for distribute-checkpointed-activations to work you '\
        'need to enable checkpoint-activations'

-    # custom kernel constraints check
-    seq_len = args.seq_length
-    attn_batch_size = \
-        (args.num_attention_heads / args.tensor_model_parallel_size) * \
-        args.micro_batch_size
-    # constraints on sequence length and attn_batch_size to enable warp based
-    # optimization and upper triangular optimization (for causal mask)
-    custom_kernel_constraint = seq_len > 16 and seq_len <= 2048 and \
-        seq_len % 4 == 0 and attn_batch_size % 4 == 0
-    if not (args.fp16 and custom_kernel_constraint and args.masked_softmax_fusion):
-        print('WARNING: constraints for invoking optimized'
-              ' fused softmax kernel are not met. We default back to unfused'
-              ' kernel invocations.')
-
-    # Load scaled_masked_softmax_fusion_kernels
-    if args.masked_softmax_fusion:
-        fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
-        fused_kernels.load_scaled_masked_softmax_fusion_kernel()
-
-    # Load mixed precision fused layer norm.
-    if args.fp32_residual_connection:
-        fused_kernels.load_fused_mix_prec_layer_norm_kernel()
-
     _print_args(args)
     return args
...
@@ -13,114 +13,98 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
 import pathlib
 import subprocess
-import os

 from torch.utils import cpp_extension

-# Setting this param to a list has a problem of generating
-# different compilation commands (with different order of architectures)
-# and leading to recompilation of fused kernels.
-# set it to empty string to avoid recompilation
-# and assign arch flags explicitly in extra_cuda_cflags below
+# Setting this param to a list has a problem of generating different
+# compilation commands (with different order of architectures) and
+# leading to recompilation of fused kernels. Set it to empty string
+# to avoid recompilation and assign arch flags explicitly in
+# extra_cuda_cflags below
 os.environ["TORCH_CUDA_ARCH_LIST"] = ""

-
-def get_cuda_bare_metal_version(cuda_dir):
-    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
-                                         universal_newlines=True)
-    output = raw_output.split()
-    release_idx = output.index("release") + 1
-    release = output[release_idx].split(".")
-    bare_metal_major = release[0]
-    bare_metal_minor = release[1][0]
-    return raw_output, bare_metal_major, bare_metal_minor
-
-
-def create_build_dir(buildpath):
-    try:
-        os.mkdir(buildpath)
-    except OSError:
-        if not os.path.isdir(buildpath):
-            print(f"Creation of the build directory {buildpath} failed")
-
-
-def load_scaled_upper_triang_masked_softmax_fusion_kernel():
-    # Check, if CUDA11 is installed for compute capability 8.0
-    cc_flag = []
-    _, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
-    if int(bare_metal_major) >= 11:
-        cc_flag.append('-gencode')
-        cc_flag.append('arch=compute_80,code=sm_80')
-
-    srcpath = pathlib.Path(__file__).parent.absolute()
-    buildpath = srcpath / 'build'
-
-    create_build_dir(buildpath)
-
-    scaled_upper_triang_masked_softmax_cuda = cpp_extension.load(
-        name='scaled_upper_triang_masked_softmax_cuda',
-        sources=[srcpath / 'scaled_upper_triang_masked_softmax.cpp',
-                 srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu'],
-        build_directory=buildpath,
-        extra_cflags=['-O3',],
-        extra_cuda_cflags=['-O3',
-                           '-gencode', 'arch=compute_70,code=sm_70',
-                           '-U__CUDA_NO_HALF_OPERATORS__',
-                           '-U__CUDA_NO_HALF_CONVERSIONS__',
-                           '--expt-relaxed-constexpr',
-                           '--expt-extended-lambda',
-                           '--use_fast_math'] + cc_flag)
-
-
-def load_scaled_masked_softmax_fusion_kernel():
-    # Check, if CUDA11 is installed for compute capability 8.0
-    cc_flag = []
-    _, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
-    if int(bare_metal_major) >= 11:
-        cc_flag.append('-gencode')
-        cc_flag.append('arch=compute_80,code=sm_80')
-
-    srcpath = pathlib.Path(__file__).parent.absolute()
-    buildpath = srcpath / 'build'
-
-    create_build_dir(buildpath)
-
-    scaled_masked_softmax_cuda = cpp_extension.load(
-        name='scaled_masked_softmax_cuda',
-        sources=[srcpath / 'scaled_masked_softmax.cpp',
-                 srcpath / 'scaled_masked_softmax_cuda.cu'],
-        build_directory=buildpath,
-        extra_cflags=['-O3',],
-        extra_cuda_cflags=['-O3',
-                           '-gencode', 'arch=compute_70,code=sm_70',
-                           '-U__CUDA_NO_HALF_OPERATORS__',
-                           '-U__CUDA_NO_HALF_CONVERSIONS__',
-                           '--expt-relaxed-constexpr',
-                           '--expt-extended-lambda',
-                           '--use_fast_math'] + cc_flag)
-
-
-def load_fused_mix_prec_layer_norm_kernel():
-    # Check, if CUDA11 is installed for compute capability 8.0
-    cc_flag = []
-    _, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
-    if int(bare_metal_major) >= 11:
-        cc_flag.append('-gencode')
-        cc_flag.append('arch=compute_80,code=sm_80')
-
-    srcpath = pathlib.Path(__file__).parent.absolute()
-    buildpath = srcpath / 'build'
-
-    create_build_dir(buildpath)
-
-    fused_mix_prec_layer_norm_cuda = cpp_extension.load(
-        name='fused_mix_prec_layer_norm_cuda',
-        sources=[srcpath / 'layer_norm_cuda.cpp',
-                 srcpath / 'layer_norm_cuda_kernel.cu'],
-        build_directory=buildpath,
-        extra_cflags=['-O3'],
-        extra_cuda_cflags=['-O3',
-                           '-gencode', 'arch=compute_70,code=sm_70',
-                           '-maxrregcount=50',
-                           '--use_fast_math'] + cc_flag)
+
+def load(args):
+
+    # Check if cuda 11 is installed for compute capability 8.0
+    cc_flag = []
+    _, bare_metal_major, _ = _get_cuda_bare_metal_version(
+        cpp_extension.CUDA_HOME)
+    if int(bare_metal_major) >= 11:
+        cc_flag.append('-gencode')
+        cc_flag.append('arch=compute_80,code=sm_80')
+
+    # Build path
+    srcpath = pathlib.Path(__file__).parent.absolute()
+    buildpath = srcpath / 'build'
+    _create_build_dir(buildpath)
+
+    # Helper function to build the kernels.
+    def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
+        return cpp_extension.load(
+            name=name,
+            sources=sources,
+            build_directory=buildpath,
+            extra_cflags=['-O3',],
+            extra_cuda_cflags=['-O3',
+                               '-gencode', 'arch=compute_70,code=sm_70',
+                               '--use_fast_math'] + extra_cuda_flags + cc_flag,
+            verbose=(args.rank == 0)
+        )
+
+    # ==============
+    # Fused softmax.
+    # ==============
+
+    if args.masked_softmax_fusion:
+        extra_cuda_flags = ['-U__CUDA_NO_HALF_OPERATORS__',
+                            '-U__CUDA_NO_HALF_CONVERSIONS__',
+                            '--expt-relaxed-constexpr',
+                            '--expt-extended-lambda']
+
+        # Upper triangular softmax.
+        sources=[srcpath / 'scaled_upper_triang_masked_softmax.cpp',
+                 srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu']
+        scaled_upper_triang_masked_softmax_cuda = _cpp_extention_load_helper(
+            "scaled_upper_triang_masked_softmax_cuda",
+            sources, extra_cuda_flags)
+
+        # Masked softmax.
+        sources=[srcpath / 'scaled_masked_softmax.cpp',
+                 srcpath / 'scaled_masked_softmax_cuda.cu']
+        scaled_masked_softmax_cuda = _cpp_extention_load_helper(
+            "scaled_masked_softmax_cuda", sources, extra_cuda_flags)
+
+    # =================================
+    # Mixed precision fused layer norm.
+    # =================================
+
+    if args.fp32_residual_connection:
+        extra_cuda_flags = ['-maxrregcount=50']
+        sources=[srcpath / 'layer_norm_cuda.cpp',
+                 srcpath / 'layer_norm_cuda_kernel.cu']
+        fused_mix_prec_layer_norm_cuda = _cpp_extention_load_helper(
+            "fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags)
+
+
+def _get_cuda_bare_metal_version(cuda_dir):
+    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
+                                         universal_newlines=True)
+    output = raw_output.split()
+    release_idx = output.index("release") + 1
+    release = output[release_idx].split(".")
+    bare_metal_major = release[0]
+    bare_metal_minor = release[1][0]
+
+    return raw_output, bare_metal_major, bare_metal_minor
+
+
+def _create_build_dir(buildpath):
+    try:
+        os.mkdir(buildpath)
+    except OSError:
+        if not os.path.isdir(buildpath):
+            print(f"Creation of the build directory {buildpath} failed")
...
@@ -26,11 +26,7 @@
 namespace {

 void compute_n1_n2(
     at::Tensor input,
-#ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-#else
-    at::IntList normalized_shape,
-#endif
     int& n1,
     int& n2)
 {
@@ -47,11 +43,7 @@ void compute_n1_n2(
 }

 void check_args(
-#ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-#else
-    at::IntList normalized_shape,
-#endif
     at::Tensor gamma,
     at::Tensor beta
     )
@@ -62,11 +54,7 @@ void check_args(
 void check_args(
     at::Tensor input,
-#ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-#else
-    at::IntList normalized_shape,
-#endif
     int& n1,
     int& n2
     )
@@ -102,11 +90,7 @@ void check_args(
 void check_args(
     at::Tensor input,
-#ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-#else
-    at::IntList normalized_shape,
-#endif
     at::Tensor gamma,
     at::Tensor beta,
     int& n1,
@@ -125,26 +109,18 @@ void cuda_layer_norm(
     at::Tensor* input,
     int n1,
     int n2,
-#ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-#else
-    at::IntList normalized_shape,
-#endif
     at::Tensor* gamma,
     at::Tensor* beta,
     double epsilon);

-#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
 #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
 #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

 std::vector<at::Tensor> layer_norm(
     at::Tensor input,
-#ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-#else
-    at::IntList normalized_shape,
-#endif
     double epsilon) {
   CHECK_INPUT(input);
   int n1,n2;
@@ -158,11 +134,7 @@ std::vector<at::Tensor> layer_norm(
 }

 std::vector<at::Tensor> layer_norm_affine(
     at::Tensor input,
-#ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-#else
-    at::IntList normalized_shape,
-#endif
     at::Tensor gamma,
     at::Tensor beta,
     double epsilon) {
@@ -186,11 +158,7 @@ void cuda_layer_norm_gradient(
     at::Tensor* input,
     int n1,
     int n2,
-#ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-#else
-    at::IntList normalized_shape,
-#endif
     at::Tensor* gamma,
     at::Tensor* beta,
     double epsilon,
@@ -204,11 +172,7 @@ at::Tensor layer_norm_gradient(
     at::Tensor mean,
     at::Tensor invvar,
     at::Tensor input,
-#ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-#else
-    at::IntList normalized_shape,
-#endif
     double epsilon) {
   CHECK_INPUT(dout);
   CHECK_INPUT(mean);
@@ -227,11 +191,7 @@ std::vector<at::Tensor> layer_norm_gradient_affine(
     at::Tensor mean,
     at::Tensor invvar,
     at::Tensor input,
-#ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-#else
-    at::IntList normalized_shape,
-#endif
     at::Tensor gamma,
     at::Tensor beta,
     double epsilon) {
...
@@ -19,7 +19,6 @@
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
 #include <cuda_profiler_api.h>
-#include "THC/THC.h"
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
 #include "scaled_masked_softmax.h"
...
@@ -19,7 +19,6 @@
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
 #include <cuda_profiler_api.h>
-#include "THC/THC.h"
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
 #include "scaled_upper_triang_masked_softmax.h"
...
@@ -17,16 +17,20 @@
 import random
 import os
+import time

 import numpy as np
 import torch

+from megatron import fused_kernels
 from megatron import get_adlr_autoresume
 from megatron import get_args
 from megatron import get_tensorboard_writer
 from megatron import mpu
 from megatron.global_vars import set_global_variables
-from megatron.mpu import set_tensor_model_parallel_rank, set_tensor_model_parallel_world_size
+from megatron.mpu import (set_tensor_model_parallel_rank,
+                          set_tensor_model_parallel_world_size)


 def initialize_megatron(extra_args_provider=None, args_defaults={},
                         ignore_unknown_args=False, allow_no_cuda=False):
@@ -37,8 +41,7 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
     what you are doing.
     Returns a function to finalize distributed env initialization
     (optionally, only when args.lazy_mpu_init == True)
-
     """
     if not allow_no_cuda:
         # Make sure cuda is available.
         assert torch.cuda.is_available(), 'Megatron requires CUDA.'
@@ -66,7 +69,8 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
         # delayed initialization of DDP-related stuff
         # We only set basic DDP globals
         set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
-        # and return function for external DDP manager to call when it has DDP initialized
+        # and return function for external DDP manager
+        # to call when it has DDP initialized
         set_tensor_model_parallel_rank(args.rank)
         return finish_mpu_init
     else:
@@ -79,15 +83,70 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
         # Autoresume.
         _init_autoresume()

-        # Compile dataset C++ code.
-        if torch.distributed.get_rank() == 0:
-            from megatron.data.dataset_utils import compile_helper
-            compile_helper()
-
-        # Simple barrier
-        torch.distributed.barrier()
+        # Compile dependencies.
+        _compile_dependencies()

         # No continuation function
         return None


+def _compile_dependencies():
+
+    args = get_args()
+
+    # =========================
+    # Compile dataset C++ code.
+    # =========================
+    # TODO: move this to ninja
+    if torch.distributed.get_rank() == 0:
+        start_time = time.time()
+        print('> compiling dataset index builder ...')
+        from megatron.data.dataset_utils import compile_helper
+        compile_helper()
+        print('>>> done with dataset index builder. Compilation time: {:.3f} '
+              'seconds'.format(time.time() - start_time), flush=True)
+
+    # ==================
+    # Load fused kernels
+    # ==================
+
+    # Custom kernel constraints check.
+    seq_len = args.seq_length
+    attn_batch_size = \
+        (args.num_attention_heads / args.tensor_model_parallel_size) * \
+        args.micro_batch_size
+    # Constraints on sequence length and attn_batch_size to enable warp based
+    # optimization and upper triangular optimization (for causal mask)
+    custom_kernel_constraint = seq_len > 16 and seq_len <= 2048 and \
+        seq_len % 4 == 0 and attn_batch_size % 4 == 0
+    # Print a warning.
+    if not ((args.fp16 or args.bf16) and
+            custom_kernel_constraint and
+            args.masked_softmax_fusion):
+        if args.rank == 0:
+            print('WARNING: constraints for invoking optimized'
+                  ' fused softmax kernel are not met. We default'
+                  ' back to unfused kernel invocations.', flush=True)
+
+    # Always build on rank zero first.
+    if torch.distributed.get_rank() == 0:
+        start_time = time.time()
+        print('> compiling and loading fused kernels ...', flush=True)
+        fused_kernels.load(args)
+        torch.distributed.barrier()
+    else:
+        torch.distributed.barrier()
+        fused_kernels.load(args)
+    # Simple barrier to make sure all ranks have passed the
+    # compilation phase successfully before moving on to the
+    # rest of the program. We think this might ensure that
+    # the lock is released.
+    torch.distributed.barrier()
+    if torch.distributed.get_rank() == 0:
+        print('>>> done with compiling and loading fused kernels. '
+              'Compilation time: {:.3f} seconds'.format(
+                  time.time() - start_time), flush=True)


 def _initialize_distributed():
...
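
The branch name ninja_compilation_fix reflects the build ordering introduced above: every rank goes through the same fused_kernels.load(args) call, but rank 0 builds first so the remaining ranks pick up the already-compiled extensions instead of racing the ninja build in the shared build directory. A minimal standalone sketch of that pattern, our own illustration rather than code from the diff, assuming torch.distributed is already initialized and with build_fn standing in for fused_kernels.load:

    import torch

    def build_on_rank_zero_first(build_fn):
        # Rank 0 runs the JIT build and populates the shared build directory.
        if torch.distributed.get_rank() == 0:
            build_fn()
            torch.distributed.barrier()
        else:
            # Other ranks wait for rank 0, then load the cached extension.
            torch.distributed.barrier()
            build_fn()
        # Final barrier so no rank moves on before every rank has loaded the kernels.
        torch.distributed.barrier()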