Unverified Commit 365fdc18 authored by Masaki Kozuki, committed by GitHub

transformer utils (#1181)


Co-authored-by: Piotr Bialecki <pbialecki@nvidia.com>
Co-authored-by: Eddie Yan <eddiey@nvidia.com>
Co-authored-by: Rishi Puri <riship@nvidia.com>
Co-authored-by: Sangkug Lym <slym@nvidia.com>
parent bdac244e
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import numpy
import torch
from apex import transformer
from apex.transformer.tensor_parallel.tests import global_vars
TEST_SUCCESS_MESSAGE = ">> passed the test :-)"

class IdentityLayer(torch.nn.Module):
    def __init__(self, size, scale=1.0):
        super(IdentityLayer, self).__init__()
        self.weight = torch.nn.Parameter(scale * torch.randn(size))

    def forward(self):
        return self.weight


def set_random_seed(seed):
    """Set random seed for reproducibility."""
    random.seed(seed)
    numpy.random.seed(seed)
    torch.manual_seed(seed)
    transformer.tensor_parallel.model_parallel_cuda_manual_seed(seed)


def initialize_distributed(backend='nccl'):
    """Initialize torch.distributed."""
    # Get local rank in case it is provided.
    # parser = argparse.ArgumentParser()
    # parser.add_argument('--local_rank', type=int, default=None,
    #                     help='local rank passed from distributed launcher')
    # args = parser.parse_args()
    args = global_vars.get_args()
    local_rank = args.local_rank

    # Get rank and world size.
    rank = int(os.getenv('RANK', '0'))
    world_size = int(os.getenv("WORLD_SIZE", '1'))
    print('> initializing torch.distributed with local rank: {}, '
          'rank: {}, world size: {}'.format(local_rank, rank, world_size))

    # Set the device id.
    device = rank % torch.cuda.device_count()
    if local_rank is not None:
        device = local_rank
    torch.cuda.set_device(device)

    # Call the init process.
    init_method = 'tcp://'
    master_ip = os.getenv('MASTER_ADDR', 'localhost')
    master_port = os.getenv('MASTER_PORT', '6000')
    init_method += master_ip + ':' + master_port
    torch.distributed.init_process_group(
        backend=backend,
        world_size=world_size,
        rank=rank,
        init_method=init_method)


def print_separator(message):
    torch.distributed.barrier()
    filler_len = (78 - len(message)) // 2
    filler = '-' * filler_len
    string = '\n' + filler + ' {} '.format(message) + filler
    if torch.distributed.get_rank() == 0:
        print(string, flush=True)
    torch.distributed.barrier()
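
For orientation, a hedged sketch of how these helpers might be strung together in a test driver. The module path and the argument set expected by global_vars.set_global_variables() are assumptions, not part of this diff:

# Hypothetical test driver for the helpers above; module path is assumed.
import torch
from apex.transformer.tensor_parallel.tests import global_vars

global_vars.set_global_variables()   # parses Megatron-style CLI args (see global_vars.py below)
initialize_distributed()             # picks the device and calls init_process_group

# Tensor-model-parallel state would be initialized here (exact API omitted)
# before set_random_seed(), since that helper also seeds the model-parallel RNG.
print_separator('identity layer smoke test')
layer = IdentityLayer((4, 4), scale=2.0).cuda()
assert layer().shape == (4, 4)
if torch.distributed.get_rank() == 0:
    print(TEST_SUCCESS_MESSAGE)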
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron global variables."""
import os
import sys
import time
import torch
from apex.transformer.tensor_parallel.microbatches import build_num_microbatches_calculator
from apex.transformer.tensor_parallel.tests.arguments import parse_args
_GLOBAL_ARGS = None
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
_GLOBAL_TOKENIZER = None
_GLOBAL_TENSORBOARD_WRITER = None
_GLOBAL_ADLR_AUTORESUME = None
_GLOBAL_TIMERS = None

def get_args():
    """Return arguments."""
    _ensure_var_is_initialized(_GLOBAL_ARGS, 'args')
    return _GLOBAL_ARGS


def get_num_microbatches():
    return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()


def get_current_global_batch_size():
    return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_current_global_batch_size()


def update_num_microbatches(consumed_samples, consistency_check=True):
    _GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples,
                                               consistency_check)


# def get_tokenizer():
#     """Return tokenizer."""
#     _ensure_var_is_initialized(_GLOBAL_TOKENIZER, 'tokenizer')
#     return _GLOBAL_TOKENIZER


def get_tensorboard_writer():
    """Return tensorboard writer. It can be None so no need
    to check if it is initialized."""
    return _GLOBAL_TENSORBOARD_WRITER


def get_adlr_autoresume():
    """ADLR autoresume object. It can be None so no need
    to check if it is initialized."""
    return _GLOBAL_ADLR_AUTORESUME


def get_timers():
    """Return timers."""
    _ensure_var_is_initialized(_GLOBAL_TIMERS, 'timers')
    return _GLOBAL_TIMERS


def set_global_variables(extra_args_provider=None, args_defaults={},
                         ignore_unknown_args=False):
    """Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers."""
    args = _parse_args(extra_args_provider=extra_args_provider,
                       defaults=args_defaults,
                       ignore_unknown_args=ignore_unknown_args)
    _build_num_microbatches_calculator(args)
    # if args.vocab_file:
    #     _ = _build_tokenizer(args)
    _set_tensorboard_writer(args)
    _set_adlr_autoresume(args)
    _set_timers()


def _parse_args(extra_args_provider=None, defaults={},
                ignore_unknown_args=False):
    """Parse entire arguments."""
    global _GLOBAL_ARGS
    _ensure_var_is_not_initialized(_GLOBAL_ARGS, 'args')
    _GLOBAL_ARGS = parse_args(extra_args_provider=extra_args_provider,
                              defaults=defaults,
                              ignore_unknown_args=ignore_unknown_args)
    return _GLOBAL_ARGS


def _build_num_microbatches_calculator(args):
    global _GLOBAL_NUM_MICROBATCHES_CALCULATOR
    _ensure_var_is_not_initialized(_GLOBAL_NUM_MICROBATCHES_CALCULATOR,
                                   'num microbatches calculator')
    _GLOBAL_NUM_MICROBATCHES_CALCULATOR = build_num_microbatches_calculator(
        args)


# def _build_tokenizer(args):
#     """Initialize tokenizer."""
#     global _GLOBAL_TOKENIZER
#     _ensure_var_is_not_initialized(_GLOBAL_TOKENIZER, 'tokenizer')
#     _GLOBAL_TOKENIZER = build_tokenizer(args)
#     return _GLOBAL_TOKENIZER


# def rebuild_tokenizer(args):
#     global _GLOBAL_TOKENIZER
#     _GLOBAL_TOKENIZER = None
#     return _build_tokenizer(args)


def _set_tensorboard_writer(args):
    """Set tensorboard writer."""
    global _GLOBAL_TENSORBOARD_WRITER
    _ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER,
                                   'tensorboard writer')
    if hasattr(args, 'tensorboard_dir') and \
       args.tensorboard_dir and args.rank == (args.world_size - 1):
        try:
            from torch.utils.tensorboard import SummaryWriter
            print('> setting tensorboard ...')
            _GLOBAL_TENSORBOARD_WRITER = SummaryWriter(
                log_dir=args.tensorboard_dir,
                max_queue=args.tensorboard_queue_size)
        except ModuleNotFoundError:
            print('WARNING: TensorBoard writing requested but is not '
                  'available (are you using PyTorch 1.1.0 or later?), '
                  'no TensorBoard logs will be written.', flush=True)


def _set_adlr_autoresume(args):
    """Initialize ADLR autoresume."""
    global _GLOBAL_ADLR_AUTORESUME
    _ensure_var_is_not_initialized(_GLOBAL_ADLR_AUTORESUME, 'adlr autoresume')
    if args.adlr_autoresume:
        if args.rank == 0:
            print('enabling autoresume ...', flush=True)
        sys.path.append(os.environ.get('SUBMIT_SCRIPTS', '.'))
        try:
            from userlib.auto_resume import AutoResume
        except BaseException:
            print('ADLR autoresume is not available, exiting ...')
            sys.exit()
        _GLOBAL_ADLR_AUTORESUME = AutoResume


def _set_timers():
    """Initialize timers."""
    global _GLOBAL_TIMERS
    _ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers')
    _GLOBAL_TIMERS = Timers()


def _ensure_var_is_initialized(var, name):
    """Make sure the input variable is not None."""
    assert var is not None, '{} is not initialized.'.format(name)


def _ensure_var_is_not_initialized(var, name):
    """Make sure the input variable is None."""
    assert var is None, '{} is already initialized.'.format(name)


class _Timer:
    """Timer."""

    def __init__(self, name):
        self.name_ = name
        self.elapsed_ = 0.0
        self.started_ = False
        self.start_time = time.time()

    def start(self):
        """Start the timer."""
        assert not self.started_, 'timer has already been started'
        torch.cuda.synchronize()
        self.start_time = time.time()
        self.started_ = True

    def stop(self):
        """Stop the timer."""
        assert self.started_, 'timer is not started'
        torch.cuda.synchronize()
        self.elapsed_ += (time.time() - self.start_time)
        self.started_ = False

    def reset(self):
        """Reset timer."""
        self.elapsed_ = 0.0
        self.started_ = False

    def elapsed(self, reset=True):
        """Calculate the elapsed time."""
        started_ = self.started_
        # If the timing is in progress, end it first.
        if self.started_:
            self.stop()
        # Get the elapsed time.
        elapsed_ = self.elapsed_
        # Reset the elapsed time.
        if reset:
            self.reset()
        # If timing was in progress, set it back.
        if started_:
            self.start()
        return elapsed_


class Timers:
    """Group of timers."""

    def __init__(self):
        self.timers = {}

    def __call__(self, name):
        if name not in self.timers:
            self.timers[name] = _Timer(name)
        return self.timers[name]

    def write(self, names, writer, iteration, normalizer=1.0, reset=False):
        """Write timers to a tensorboard writer."""
        # Currently when using add_scalars,
        # torch.utils.add_scalars makes each timer its own run, which
        # pollutes the runs list, so we just add each as a scalar.
        assert normalizer > 0.0
        for name in names:
            value = self.timers[name].elapsed(reset=reset) / normalizer
            writer.add_scalar(name + '-time', value, iteration)

    def log(self, names, normalizer=1.0, reset=True):
        """Log a group of timers."""
        assert normalizer > 0.0
        string = 'time (ms)'
        for name in names:
            elapsed_time = self.timers[name].elapsed(
                reset=reset) * 1000.0 / normalizer
            string += ' | {}: {:.2f}'.format(name, elapsed_time)
        if torch.distributed.is_initialized():
            if torch.distributed.get_rank() == (
                    torch.distributed.get_world_size() - 1):
                print(string, flush=True)
        else:
            print(string, flush=True)
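
As a hedged, illustrative sketch of the Timers group above (the linear layer, batch size, and iteration count are placeholders chosen only to make the example runnable):

# Illustrative only: time a tiny layer with the Timers group defined above.
import torch

model = torch.nn.Linear(1024, 1024).cuda()
batch = torch.randn(64, 1024, device='cuda')
timers = Timers()

for _ in range(10):
    timers('forward').start()
    out = model(batch)
    timers('forward').stop()

    timers('backward').start()
    out.sum().backward()
    timers('backward').stop()

# Prints something like "time (ms) | forward: 0.42 | backward: 0.87";
# the normalizer turns the accumulated time into a per-iteration average.
timers.log(['forward', 'backward'], normalizer=10)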
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

def ensure_divisibility(numerator, denominator):
    """Ensure that numerator is divisible by the denominator."""
    assert numerator % denominator == 0, "{} is not divisible by {}".format(
        numerator, denominator)


def divide(numerator, denominator):
    """Ensure that numerator is divisible by the denominator and return
    the division value."""
    ensure_divisibility(numerator, denominator)
    return numerator // denominator


def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
    """Split a tensor along its last dimension.

    Arguments:
        tensor: input tensor.
        num_partitions: number of partitions to split the tensor into.
        contiguous_split_chunks: If True, make each chunk contiguous
            in memory.
    """
    # Get the size and dimension.
    last_dim = tensor.dim() - 1
    last_dim_size = divide(tensor.size()[last_dim], num_partitions)
    # Split.
    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
    # Note: torch.split does not create contiguous tensors by default.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in tensor_list)
    return tensor_list


class VocabUtility:
    """Split the vocabulary into `world_size` chunks and return the
    first and last index of the vocabulary belonging to the `rank`
    partition. Note that indices are in [first, last)."""

    @staticmethod
    def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size):
        index_f = rank * per_partition_vocab_size
        index_l = index_f + per_partition_vocab_size
        return index_f, index_l

    @staticmethod
    def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
        per_partition_vocab_size = divide(global_vocab_size, world_size)
        return VocabUtility.vocab_range_from_per_partition_vocab_size(
            per_partition_vocab_size, rank, world_size)
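
A short, hedged example of the two utilities above; the tensor shape and the padded vocabulary size 50304 are illustrative values, not taken from this diff:

import torch

# Split a [4, 6] tensor into 3 chunks of shape [4, 2] along the last dim.
x = torch.randn(4, 6)
chunks = split_tensor_along_last_dim(x, num_partitions=3,
                                     contiguous_split_chunks=True)
assert all(c.shape == (4, 2) and c.is_contiguous() for c in chunks)

# Rank 1 of 4 owns rows [vocab // 4, 2 * vocab // 4) of a vocab-parallel
# embedding table.
start, end = VocabUtility.vocab_range_from_global_vocab_size(
    global_vocab_size=50304, rank=1, world_size=4)
assert (start, end) == (12576, 25152)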
@@ -130,12 +130,13 @@ std::vector<at::Tensor> layer_norm(
  int n1,n2;
  check_args(input,normalized_shape,n1,n2);
  at::Tensor output = at::empty_like(input);
  at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half || input.scalar_type()==at::ScalarType::BFloat16 ? at::ScalarType::Float : input.scalar_type()));
  at::Tensor invvar = at::empty_like(mean);
  cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2,
      normalized_shape,NULL,NULL,epsilon);
  return {output, mean, invvar};
}

std::vector<at::Tensor> layer_norm_affine(
    at::Tensor input,
    #ifdef VERSION_GE_1_1

@@ -152,13 +153,35 @@ std::vector<at::Tensor> layer_norm_affine(
  int n1,n2;
  check_args(input,normalized_shape,gamma,beta,n1,n2);
  at::Tensor output = at::empty_like(input);
  const auto stats_dtype = (input.scalar_type() == at::ScalarType::Half || input.scalar_type() == at::ScalarType::BFloat16) ? at::ScalarType::Float : input.scalar_type();
  at::Tensor mean = at::empty({n1}, input.options().dtype(stats_dtype));
  at::Tensor invvar = at::empty_like(mean);
  cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2,
      normalized_shape,&gamma,&beta,epsilon);
  return {output, mean, invvar};
}

std::vector<at::Tensor> layer_norm_affine_mixed_dtypes(
    at::Tensor input,
    #ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
    #else
    at::IntList normalized_shape,
    #endif
    at::Tensor gamma,
    at::Tensor beta,
    double epsilon) {
  CHECK_INPUT(input);
  int n1, n2;
  check_args(input, normalized_shape, n1, n2);
  at::Tensor output = at::empty_like(input, gamma.options().dtype(gamma.scalar_type()));
  at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type() == at::ScalarType::Half || input.scalar_type() == at::ScalarType::BFloat16 ? at::ScalarType::Float : input.scalar_type()));
  at::Tensor invvar = at::empty_like(mean);

  cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2,
      normalized_shape, &gamma, &beta, epsilon);

  return {output, mean, invvar};
}

void cuda_layer_norm_gradient(
    at::Tensor* dout,
    at::Tensor* mean,

@@ -202,6 +225,7 @@ at::Tensor layer_norm_gradient(
      &grad_input,NULL,NULL);
  return grad_input;
}

std::vector<at::Tensor> layer_norm_gradient_affine(
    at::Tensor dout,
    at::Tensor mean,

@@ -237,5 +261,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &layer_norm, "LayerNorm forward (CUDA)");
  m.def("backward_affine", &layer_norm_gradient_affine, "LayerNorm backward (CUDA)");
  m.def("backward", &layer_norm_gradient, "LayerNorm backward (CUDA)");
  m.def("forward_affine_mixed_dtypes", &layer_norm_affine_mixed_dtypes, "LayerNorm forward with mixed dtypes (CUDA) compatible with Megatron's implementation");
}
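
A hedged Python sketch of the new forward_affine_mixed_dtypes entry point. The extension name fused_layer_norm_cuda, and the assumption that the bf16-input/fp32-weight pairing is among the dispatched dtype combinations, are mine, not part of this diff:

import torch
import fused_layer_norm_cuda  # assumed extension name for this .cpp file

hidden = 1024
x = torch.randn(8, hidden, device='cuda', dtype=torch.bfloat16)
# "Mixed dtypes": bf16 activations normalized with fp32 affine parameters.
weight = torch.ones(hidden, device='cuda', dtype=torch.float32)
bias = torch.zeros(hidden, device='cuda', dtype=torch.float32)

out, mean, invvar = fused_layer_norm_cuda.forward_affine_mixed_dtypes(
    x, [hidden], weight, bias, 1e-5)
# Per the allocation above: `out` takes gamma's dtype, while `mean`/`invvar`
# stay in fp32 for half/bf16 inputs.
print(out.dtype, mean.dtype, invvar.dtype)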
@@ -56,7 +56,7 @@ void cuWelfordMuSigma2(
  const int i1,
  U& mu,
  U& sigma2,
  U* buf)
{
  // Assumptions:
  // 1) blockDim.x == warpSize

@@ -140,7 +140,7 @@ void cuWelfordMuSigma2(
  const int i1,
  float& mu,
  float& sigma2,
  float* buf)
{
  // Assumptions:
  // 1) blockDim.x == warpSize

@@ -173,7 +173,7 @@ void cuWelfordMuSigma2(
      for (int k = 0; k < 8; k+=2) {
        float2 curr = __half22float2(*((__half2*)(lvals+l+k)));
        cuWelfordOnlineSum(curr.x,mu,sigma2,count);
        cuWelfordOnlineSum(curr.y,mu,sigma2,count);
      }
    }
    for (; l < n2; ++l) {

@@ -276,18 +276,18 @@ struct SharedMemory <double>
};
}

template<typename T, typename U, typename V> __device__
void cuApplyLayerNorm_(
  V* __restrict__ output_vals,
  U* __restrict__ mean,
  U* __restrict__ invvar,
  const T* __restrict__ vals,
  const int n1,
  const int n2,
  const U epsilon,
  const V* __restrict__ gamma,
  const V* __restrict__ beta
  )
{
  // Assumptions:
  // 1) blockDim.x == warpSize

@@ -299,19 +299,19 @@ void cuApplyLayerNorm(
    U mu,sigma2;
    cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf);
    const T* lvals = vals + i1*n2;
    V* ovals = output_vals + i1*n2;
    U c_invvar = rsqrt(sigma2 + epsilon);
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    if (gamma != NULL && beta != NULL) {
      for (int i = thrx; i < n2; i+=numx) {
        U curr = static_cast<U>(lvals[i]);
        ovals[i] = gamma[i] * static_cast<V>(c_invvar * (curr - mu)) + beta[i];
      }
    } else {
      for (int i = thrx; i < n2; i+=numx) {
        U curr = static_cast<U>(lvals[i]);
        ovals[i] = static_cast<V>(c_invvar * (curr - mu));
      }
    }
    if (threadIdx.x == 0 && threadIdx.y == 0) {

@@ -321,7 +321,24 @@ void cuApplyLayerNorm(
  }
}

template<typename T, typename U, typename V=T> __global__
void cuApplyLayerNorm(
  V* __restrict__ output_vals,
  U* __restrict__ mean,
  U* __restrict__ invvar,
  const T* __restrict__ vals,
  const int n1,
  const int n2,
  const U epsilon,
  const V* __restrict__ gamma,
  const V* __restrict__ beta
  )
{
  cuApplyLayerNorm_<T, U, V>(output_vals, mean, invvar, vals, n1, n2, epsilon, gamma, beta);
}

template<typename T, typename U, typename V> __device__
void cuLoadWriteStridedInputs(
    const int i1_block,
    const int thr_load_row_off,

@@ -331,7 +348,7 @@ void cuLoadWriteStridedInputs(
    U* warp_buf1,
    U* warp_buf2,
    const T* input,
    const V* dout,
    const int i1_end,
    const int n2,
    const U* __restrict__ mean,

@@ -348,9 +365,9 @@ void cuLoadWriteStridedInputs(
      int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
      if (i2<n2) {
        U curr_input = static_cast<U>(input[load_idx]);
        U curr_dout = static_cast<U>(dout[load_idx]);
        warp_buf1[write_idx] = curr_dout;
        warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar;
      } else {
        warp_buf1[write_idx] = U(0);
        warp_buf2[write_idx] = U(0);

@@ -365,7 +382,7 @@ void cuLoadWriteStridedInputs(
  }
}

template<typename T, typename U, typename V> __device__
void cuLoadAddStridedInputs(
    const int i1_block,
    const int thr_load_row_off,

@@ -375,7 +392,7 @@ void cuLoadAddStridedInputs(
    U* warp_buf1,
    U* warp_buf2,
    const T* input,
    const V* dout,
    const int i1_end,
    const int n2,
    const U* __restrict__ mean,

@@ -392,17 +409,17 @@ void cuLoadAddStridedInputs(
      int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
      if (i2<n2) {
        U curr_input = static_cast<U>(input[load_idx]);
        U curr_dout = static_cast<U>(dout[load_idx]);
        warp_buf1[write_idx] += curr_dout;
        warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar;
      }
    }
  }
}

template<typename T, typename U, typename V> __global__
void cuComputePartGradGammaBeta(
    const V* __restrict__ dout,
    const T* __restrict__ input,
    const int n1,
    const int n2,

@@ -449,11 +466,11 @@ void cuComputePartGradGammaBeta(
    for (int offset = blockDim.y/2; offset > 1; offset /= 2) {
      if (threadIdx.y < offset) {
        int row1 = threadIdx.y;
        int row2 = threadIdx.y + offset;
        int idx1 = row1*row_stride + threadIdx.x;
        int idx2 = row2*row_stride + threadIdx.x;
        warp_buf1[idx1] += warp_buf1[idx2];
        warp_buf2[idx1] += warp_buf2[idx2];
      }
      __syncthreads();
    }

@@ -468,19 +485,19 @@ void cuComputePartGradGammaBeta(
  }
}

template<typename U, typename V> __global__
void cuComputeGradGammaBeta(
    const U* part_grad_gamma,
    const U* part_grad_beta,
    const int part_size,
    const int n1,
    const int n2,
    V* grad_gamma,
    V* grad_beta)
{
  // sum partial gradients for gamma and beta
  SharedMemory<U> shared;
  U* buf = shared.getPointer();
  int i2 = blockIdx.x * blockDim.x + threadIdx.x;
  if (i2 < n2) {
    // each warp does sequential reductions until reduced part_size is num_warps

@@ -519,16 +536,16 @@ void cuComputeGradGammaBeta(
  }
}

template<typename T, typename U, typename V> __global__
void cuComputeGradInput(
    const V* __restrict__ dout,
    const T* __restrict__ input,
    const int n1,
    const int n2,
    const U* __restrict__ mean,
    const U* __restrict__ invvar,
    U epsilon,
    const V* gamma,
    T* grad_input)
{
  for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {

@@ -537,7 +554,7 @@ void cuComputeGradInput(
    const U c_mean = mean[i1];
    const U c_invvar = invvar[i1];
    const T* k_input = input + i1*n2;
    const V* k_dout = dout + i1*n2;
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    if (gamma != NULL) {

@@ -581,7 +598,7 @@ void cuComputeGradInput(
    // inter-warp reductions
    if (blockDim.y > 1) {
      SharedMemory<U> shared;
      U* buf = shared.getPointer();
      for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
        // upper half of warps write to shared
        if (threadIdx.y >= offset && threadIdx.y < 2*offset) {

@@ -606,7 +623,7 @@ void cuComputeGradInput(
        if (threadIdx.y !=0) {
          sum_loss1 = buf[2*threadIdx.x];
          sum_loss2 = buf[2*threadIdx.x+1];
        }
      }
      // all threads now have the two sums over l
      U fH = (U)n2;

@@ -636,35 +653,29 @@ void cuComputeGradInput(
  }
}

template<typename T, typename U, typename V=T>
void HostApplyLayerNorm(
    V* output,
    U* mean,
    U* invvar,
    const T* input,
    int n1,
    int n2,
    double epsilon,
    const V* gamma,
    const V* beta
    )
{
  auto stream = at::cuda::getCurrentCUDAStream().stream();
  const dim3 threads(32,4,1);
  const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
  const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
  int nshared =
      threads.y > 1 ?
          threads.y*sizeof(U)+(threads.y/2)*sizeof(U) :
          0;
  cuApplyLayerNorm<<<blocks, threads, nshared, stream>>>(
    output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta);
}

void cuda_layer_norm(

@@ -684,34 +695,35 @@ void cuda_layer_norm(
    double epsilon)
{
  using namespace at;
  DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
      input->scalar_type(), output->scalar_type(), "layer_norm_cuda_kernel",
      using accscalar_t = at::acc_type<scalar_t_in, true>;
      HostApplyLayerNorm<scalar_t_in, accscalar_t, scalar_t_out>(
        output->DATA_PTR<scalar_t_out>(),
        mean->DATA_PTR<accscalar_t>(),
        invvar->DATA_PTR<accscalar_t>(),
        input->DATA_PTR<scalar_t_in>(),
        n1,n2,
        epsilon,
        gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
        beta != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL);
    )
}

template<typename T, typename U=float, typename V=T>
void HostLayerNormGradient(
    const V* dout,
    const U* mean,
    const U* invvar,
    at::Tensor* input,
    int n1,
    int n2,
    const V* gamma,
    const V* beta,
    double epsilon,
    T* grad_input,
    V* grad_gamma,
    V* grad_beta
    )
{
  auto stream = at::cuda::getCurrentCUDAStream().stream();

@@ -724,7 +736,13 @@ void HostLayerNormGradient(
  const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
  const int nshared2_b = threads2.x * threads2.y * sizeof(U);
  const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
  // note (mkozuki): I can hard code part_grad_gamma's dtype as float given that
  // the `cuda_layer_norm_gradient` doesn't support double.
  const auto part_grad_dtype =
    (input->scalar_type() == at::ScalarType::Half || input->scalar_type() == at::ScalarType::BFloat16) ?
      at::ScalarType::Float :
      input->scalar_type();
  at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype(part_grad_dtype));
  at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
  cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
    dout,

@@ -787,21 +805,23 @@ void cuda_layer_norm_gradient(
    at::Tensor* grad_beta)
{
  using namespace at;
  // we can do away with `accscalar_t` as there're only three dtypes: fp32, fp16, bf16
  DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
    input->scalar_type(), gamma == NULL ? input->scalar_type() : gamma->scalar_type(), "cuComputeGradInput",
    using accscalar_t = at::acc_type<scalar_t_in, true>;
    HostLayerNormGradient(
      dout->DATA_PTR<scalar_t_out>(),
      mean->DATA_PTR<accscalar_t>(),
      invvar->DATA_PTR<accscalar_t>(),
      input,
      n1,n2,
      // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
      // if gamma Tensor is NULL on input.
      gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
      gamma != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL,
      epsilon,
      grad_input->DATA_PTR<scalar_t_in>(),
      gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL,
      gamma != NULL ? grad_beta->DATA_PTR<scalar_t_out>() : NULL);
  )
}
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {
torch::Tensor fwd_cuda(
torch::Tensor const& input,
torch::Tensor const& mask,
float scale_factor);
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor);
int get_batch_per_block_cuda(
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads);
torch::Tensor fwd(
torch::Tensor const& input,
torch::Tensor const& mask,
float scale_factor) {
AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");
return fwd_cuda(input, mask, scale_factor);
}
torch::Tensor bwd(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor) {
AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor");
AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor");
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return bwd_cuda(output_grads, softmax_results, scale_factor);
}
int get_batch_per_block(
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads) {
return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads);
}
} // end namespace scaled_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward",
&multihead_attn::fused_softmax::scaled_masked_softmax::fwd,
"Self Multihead Attention scaled, time masked softmax -- Forward.");
m.def("backward",
&multihead_attn::fused_softmax::scaled_masked_softmax::bwd,
"Self Multihead Attention scaled, time masked softmax -- Backward.");
m.def("get_batch_per_block",
&multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block,
"Return Batch per block size."
);
}
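
A hedged Python sketch of driving the bindings above. The extension name scaled_masked_softmax_cuda matches the setup.py entry added later in this diff; the mask convention (non-zero entries are masked out) and the example shapes are assumptions:

import torch
import scaled_masked_softmax_cuda  # name from the setup.py addition below

b, heads, sq, sk = 4, 16, 128, 128
scores = torch.randn(b, heads, sq, sk, device='cuda', dtype=torch.float16)
# Mask is broadcast over heads (size(1) must be 1); the kernel reads it as uint8.
mask = torch.zeros(b, 1, sq, sk, device='cuda', dtype=torch.uint8)
scale = 0.125  # e.g. 1/sqrt(head_dim) for head_dim=64; illustrative

probs = scaled_masked_softmax_cuda.forward(scores, mask, scale)
grads = scaled_masked_softmax_cuda.backward(torch.ones_like(probs), probs, scale)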
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_masked_softmax.h"
#include "type_shim.h"
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {
int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){
return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
}
torch::Tensor fwd_cuda(
torch::Tensor const& input,
torch::Tensor const& mask,
float scale_factor)
{
// input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
const int batches = input.size(0);
const int pad_batches = mask.size(0);
const int attn_heads = input.size(1);
const int query_seq_len = input.size(2);
const int key_seq_len = input.size(3);
TORCH_INTERNAL_ASSERT(key_seq_len <= 2048);
TORCH_INTERNAL_ASSERT(query_seq_len > 1);
TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);
TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);
// Output
auto act_options = input.options().requires_grad(false);
torch::Tensor softmax_results =
torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);
// Softmax Intermediate Result Ptr
void* input_ptr = static_cast<void*>(input.data_ptr());
void* mask_ptr = static_cast<void*>(mask.data_ptr());
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
DISPATCH_HALF_AND_BFLOAT(
input.scalar_type(),
"dispatch_scaled_masked_softmax_forward",
dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(softmax_results_ptr),
reinterpret_cast<const scalar_t*>(input_ptr),
reinterpret_cast<const uint8_t*>(mask_ptr),
scale_factor,
query_seq_len,
key_seq_len,
batches,
attn_heads,
pad_batches);
);
return softmax_results;
}
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads_,
torch::Tensor const& softmax_results_,
float scale_factor) {
auto output_grads = output_grads_.contiguous();
auto softmax_results = softmax_results_.contiguous();
//output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
const int batches = output_grads.size(0);
const int attn_heads = output_grads.size(1);
const int query_seq_len = output_grads.size(2);
const int key_seq_len = output_grads.size(3);
void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
//Softmax Grad
DISPATCH_HALF_AND_BFLOAT(
output_grads_.scalar_type(),
"dispatch_scaled_masked_softmax_backward",
dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
scale_factor,
query_seq_len,
key_seq_len,
batches,
attn_heads);
);
//backward pass is completely in-place
return output_grads;
}
}
}
}
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {
torch::Tensor fwd_cuda(
torch::Tensor const& input,
float scale_factor);
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor);
torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return fwd_cuda(input, scale_factor);
}
torch::Tensor bwd(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return bwd_cuda(output_grads, softmax_results, scale_factor);
}
} // end namespace scaled_upper_triang_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward",
&multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
"Self Multihead Attention scaled, time masked softmax -- Forward.");
m.def("backward",
&multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
"Self Multihead Attention scaled, time masked softmax -- Backward.");
}
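
And a matching hedged sketch for the upper-triangular (causal) variant, which takes a 3D input of shape [batches * attn_heads, seq_len, seq_len] and applies the triangular mask implicitly; the extension name comes from the setup.py entry below, the shapes are illustrative:

import torch
import scaled_upper_triang_masked_softmax_cuda  # name from the setup.py addition below

scores = torch.randn(4 * 16, 128, 128, device='cuda', dtype=torch.float16)
probs = scaled_upper_triang_masked_softmax_cuda.forward(scores, 1.0)
grads = scaled_upper_triang_masked_softmax_cuda.backward(
    torch.ones_like(probs), probs, 1.0)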
@@ -206,6 +206,30 @@ if "--cuda_ext" in sys.argv:
                      extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                          'nvcc':['-O3'] + version_dependent_macros}))

    ext_modules.append(
        CUDAExtension(name='scaled_upper_triang_masked_softmax_cuda',
                      sources=['csrc/megatron/scaled_upper_triang_masked_softmax.cpp',
                               'csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu'],
                      include_dirs=[os.path.join(this_dir, 'csrc')],
                      extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                          'nvcc':['-O3',
                                                  '-U__CUDA_NO_HALF_OPERATORS__',
                                                  '-U__CUDA_NO_HALF_CONVERSIONS__',
                                                  '--expt-relaxed-constexpr',
                                                  '--expt-extended-lambda'] + version_dependent_macros}))

    ext_modules.append(
        CUDAExtension(name='scaled_masked_softmax_cuda',
                      sources=['csrc/megatron/scaled_masked_softmax.cpp',
                               'csrc/megatron/scaled_masked_softmax_cuda.cu'],
                      include_dirs=[os.path.join(this_dir, 'csrc')],
                      extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                          'nvcc':['-O3',
                                                  '-U__CUDA_NO_HALF_OPERATORS__',
                                                  '-U__CUDA_NO_HALF_CONVERSIONS__',
                                                  '--expt-relaxed-constexpr',
                                                  '--expt-extended-lambda'] + version_dependent_macros}))

if "--bnp" in sys.argv:
    sys.argv.remove("--bnp")

@@ -495,6 +519,7 @@ if "--fast_bottleneck" in sys.argv:
                      include_dirs=[os.path.join(this_dir, 'apex/contrib/csrc/cudnn-frontend/include')],
                      extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag}))

setup(
    name='apex',
    version='0.1',
    ...