Unverified Commit cc92a4b4 authored by Jithun Nair's avatar Jithun Nair Committed by GitHub
Browse files

Merge pull request #55 from ROCmSoftwarePlatform/IFU-master-2021-10-15

IFU-2021-10-15 (+ remove redundant defines + C10_CUDA_CHECK)
parents 1e0f9bc6 fec3141c
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from ..parallel_state import get_tensor_model_parallel_group
from ..parallel_state import get_tensor_model_parallel_rank
from ..parallel_state import get_tensor_model_parallel_world_size
from .utils import VocabUtility
class _VocabParallelCrossEntropy(torch.autograd.Function):
@staticmethod
def forward(ctx, vocab_parallel_logits, target):
# Maximum value along vocab dimension across all GPUs.
logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
torch.distributed.all_reduce(
logits_max, op=torch.distributed.ReduceOp.MAX, group=get_tensor_model_parallel_group()
)
# Subtract the maximum value.
vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
# Get the partition's vocab indecies
get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
partition_vocab_size = vocab_parallel_logits.size()[-1]
rank = get_tensor_model_parallel_rank()
world_size = get_tensor_model_parallel_world_size()
vocab_start_index, vocab_end_index = get_vocab_range(partition_vocab_size, rank, world_size)
# Create a mask of valid vocab ids (1 means it needs to be masked).
target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
masked_target = target.clone() - vocab_start_index
masked_target[target_mask] = 0
# Get predicted-logits = logits[target].
# For Simplicity, we convert logits to a 2-D tensor with size
# [*, partition-vocab-size] and target to a 1-D tensor of size [*].
logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
masked_target_1d = masked_target.view(-1)
arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
predicted_logits_1d = predicted_logits_1d.clone().contiguous()
predicted_logits = predicted_logits_1d.view_as(target)
predicted_logits[target_mask] = 0.0
# All reduce is needed to get the chunks from other GPUs.
torch.distributed.all_reduce(
predicted_logits, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group()
)
# Sum of exponential of logits along vocab dimension across all GPUs.
exp_logits = vocab_parallel_logits
torch.exp(vocab_parallel_logits, out=exp_logits)
sum_exp_logits = exp_logits.sum(dim=-1)
torch.distributed.all_reduce(
sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group()
)
# Loss = log(sum(exp(logits))) - predicted-logit.
loss = torch.log(sum_exp_logits) - predicted_logits
# Store softmax, target-mask and masked-target for backward pass.
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
return loss
@staticmethod
def backward(ctx, grad_output):
# Retreive tensors from the forward path.
softmax, target_mask, masked_target_1d = ctx.saved_tensors
# All the inputs have softmax as thier gradient.
grad_input = softmax
# For simplicity, work with the 2D gradient.
partition_vocab_size = softmax.size()[-1]
grad_2d = grad_input.view(-1, partition_vocab_size)
# Add the gradient from matching classes.
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
grad_2d[arange_1d, masked_target_1d] -= 1.0 - target_mask.view(-1).float()
# Finally elementwise multiplication with the output gradients.
grad_input.mul_(grad_output.unsqueeze(dim=-1))
return grad_input, None
def vocab_parallel_cross_entropy(vocab_parallel_logits, target):
"""Helper function for the cross entropy."""
return _VocabParallelCrossEntropy.apply(torch.clone(vocab_parallel_logits), target)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from ..parallel_state import get_tensor_model_parallel_group
from ..parallel_state import get_tensor_model_parallel_rank
from ..parallel_state import get_tensor_model_parallel_src_rank
_MAX_DATA_DIM = 5
def _check_data_types(keys, data, target_dtype):
"""Check that all the keys have the same target data type."""
for key in keys:
assert data[key].dtype == target_dtype, "{} has data type {} which " "is different than {}".format(
key, data[key].dtype, target_dtype
)
def _build_key_size_numel_dictionaries(keys, data):
"""Build the size on rank 0 and broadcast."""
max_dim = _MAX_DATA_DIM
sizes = [0 for _ in range(max_dim) for _ in keys]
# Pack the sizes on rank zero.
if get_tensor_model_parallel_rank() == 0:
offset = 0
for key in keys:
assert data[key].dim() < max_dim, "you should increase MAX_DATA_DIM"
size = data[key].size()
for i, s in enumerate(size):
sizes[i + offset] = s
offset += max_dim
# Move to GPU and broadcast.
sizes_cuda = torch.cuda.LongTensor(sizes)
torch.distributed.broadcast(
sizes_cuda, get_tensor_model_parallel_src_rank(), group=get_tensor_model_parallel_group(),
)
# Move back to cpu and unpack.
sizes_cpu = sizes_cuda.cpu()
key_size = {}
key_numel = {}
total_numel = 0
offset = 0
for key in keys:
i = 0
size = []
numel = 1
while sizes_cpu[offset + i] > 0:
this_size = sizes_cpu[offset + i]
size.append(this_size)
numel *= this_size
i += 1
key_size[key] = size
key_numel[key] = numel
total_numel += numel
offset += max_dim
return key_size, key_numel, total_numel
def broadcast_data(keys, data, datatype):
"""Broadcast data from rank zero of each model parallel group to the
members of the same model parallel group.
Arguments:
keys: list of keys in the data disctionary to be broadcasted
data: data dictionary of string keys and cpu tensor values.
datatype: torch data type of all tensors in data associated
with keys.
"""
# Build (key, size) and (key, number of elements) dictionaries along
# with the total number of elements on all ranks.
key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys, data)
# Pack on rank zero.
if get_tensor_model_parallel_rank() == 0:
# Check that all keys have the same data type.
_check_data_types(keys, data, datatype)
# Flatten the data associated with the keys
flatten_data = torch.cat([data[key].contiguous().view(-1) for key in keys], dim=0).cuda()
else:
flatten_data = torch.empty(total_numel, device=torch.cuda.current_device(), dtype=datatype)
# Broadcast
torch.distributed.broadcast(
flatten_data, get_tensor_model_parallel_src_rank(), group=get_tensor_model_parallel_group(),
)
# Unpack
output = {}
offset = 0
for key in keys:
size = key_size[key]
numel = key_numel[key]
output[key] = flatten_data.narrow(0, offset, numel).view(size)
offset += numel
return output
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
import torch
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.parameter import Parameter
from ..._autocast_utils import _cast_if_autocast_enabled
from ..parallel_state import get_tensor_model_parallel_group
from ..parallel_state import get_tensor_model_parallel_rank
from ..parallel_state import get_tensor_model_parallel_world_size
from .mappings import copy_to_tensor_model_parallel_region
from .mappings import gather_from_tensor_model_parallel_region
from .mappings import reduce_from_tensor_model_parallel_region
from .mappings import scatter_to_tensor_model_parallel_region
from .random import get_cuda_rng_tracker
from .utils import divide
from .utils import VocabUtility
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
"tensor_model_parallel": False,
"partition_dim": -1,
"partition_stride": 1,
}
def param_is_not_tensor_parallel_duplicate(param):
return (hasattr(param, "tensor_model_parallel") and param.tensor_model_parallel) or (
get_tensor_model_parallel_rank() == 0
)
def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
# Make sure the attributes are not set.
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
assert not hasattr(tensor, attribute)
# Set the attributes.
setattr(tensor, "tensor_model_parallel", is_parallel)
setattr(tensor, "partition_dim", dim)
setattr(tensor, "partition_stride", stride)
def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor):
def maybe_set(attribute, value):
if not hasattr(tensor, attribute):
setattr(tensor, attribute, value)
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
maybe_set(attribute, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS[attribute])
def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
def maybe_copy(attribute):
if hasattr(source_tensor, attribute):
setattr(destination_tensor, attribute, getattr(source_tensor, attribute))
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
maybe_copy(attribute)
def _initialize_affine_weight_gpu(weight, init_method, partition_dim, stride=1):
"""Initialize affine weight for model parallel on GPU."""
set_tensor_model_parallel_attributes(tensor=weight, is_parallel=True, dim=partition_dim, stride=stride)
with get_cuda_rng_tracker().fork():
init_method(weight)
# TODO (mkozuki): Re-consider removing params_dtype from arguments to make this
# more parallel with _initialize_affine_weight_gpu
def _initialize_affine_weight_cpu(
weight,
output_size,
input_size,
per_partition_size,
partition_dim,
init_method,
stride=1,
return_master_weight=False,
*,
params_dtype=torch.float32,
):
"""Initialize affine weight for model parallel.
Build the master weight on all processes and scatter
the relevant chunk."""
set_tensor_model_parallel_attributes(tensor=weight, is_parallel=True, dim=partition_dim, stride=stride)
# Initialize master weight
master_weight = torch.empty(output_size, input_size, dtype=torch.float, requires_grad=False)
init_method(master_weight)
master_weight = master_weight.to(dtype=params_dtype)
# Split and copy
per_partition_per_stride_size = divide(per_partition_size, stride)
weight_list = torch.split(master_weight, per_partition_per_stride_size, dim=partition_dim)
rank = get_tensor_model_parallel_rank()
world_size = get_tensor_model_parallel_world_size()
my_weight_list = weight_list[rank::world_size]
with torch.no_grad():
torch.cat(my_weight_list, dim=partition_dim, out=weight)
if return_master_weight:
return master_weight
return None
class VocabParallelEmbedding(torch.nn.Module):
"""Embedding parallelized in the vocabulary dimension.
This is mainly adapted from torch.nn.Embedding and all the default
values are kept.
Arguments:
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
init_method: method to initialize weights.
"""
def __init__(
self, num_embeddings, embedding_dim, init_method=init.xavier_normal_, *, params_dtype=torch.float32, use_cpu_initialization=False,
):
super(VocabParallelEmbedding, self).__init__()
# Keep the input dimensions.
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
# Set the detauls for compatibility.
self.padding_idx = None
self.max_norm = None
self.norm_type = 2.0
self.scale_grad_by_freq = False
self.sparse = False
self._weight = None
self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
# Divide the weight matrix along the vocaburaly dimension.
self.vocab_start_index, self.vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size(
self.num_embeddings, get_tensor_model_parallel_rank(), self.tensor_model_parallel_size
)
self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index
# Allocate weights and initialize.
if use_cpu_initialization:
self.weight = Parameter(
torch.empty(self.num_embeddings_per_partition, self.embedding_dim, dtype=params_dtype)
)
_initialize_affine_weight_cpu(
self.weight, self.num_embeddings, self.embedding_dim, self.num_embeddings_per_partition, 0, init_method,
params_dtype=params_dtype,
)
else:
self.weight = Parameter(
torch.empty(
self.num_embeddings_per_partition,
self.embedding_dim,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
_initialize_affine_weight_gpu(self.weight, init_method, partition_dim=0, stride=1)
def forward(self, input_):
if self.tensor_model_parallel_size > 1:
# Build the mask.
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
else:
masked_input = input_
# Get the embeddings.
output_parallel = F.embedding(
masked_input,
self.weight,
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse,
)
# Mask the output embedding.
if self.tensor_model_parallel_size > 1:
output_parallel[input_mask, :] = 0.0
# Reduce across all the model parallel GPUs.
output = reduce_from_tensor_model_parallel_region(output_parallel)
return output
class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):
"""
Column-parallel linear layer execution with asynchronous all-reduce
execution in backprop.
"""
@staticmethod
def forward(ctx, input, weight, bias):
ctx.save_for_backward(input, weight)
ctx.use_bias = bias is not None
output = torch.matmul(input, weight.t())
if bias is not None:
output = output + bias
return output
@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
use_bias = ctx.use_bias
grad_input = grad_output.matmul(weight)
# Asynchronous all-reduce
handle = torch.distributed.all_reduce(
grad_input, group=get_tensor_model_parallel_group(), async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# all-reduce scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
grad_weight = grad_output.t().matmul(input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
handle.wait()
return grad_input, grad_weight, grad_bias
def column_parallel_linear(input, weight, bias):
args = _cast_if_autocast_enabled(input, weight, bias)
with torch.cuda.amp.autocast(enabled=False):
return ColumnParallelLinearWithAsyncAllreduce.apply(*args)
class ColumnParallelLinear(torch.nn.Module):
"""Linear layer with column parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its second dimension as A = [A_1, ..., A_p].
Arguments:
input_size: first dimension of matrix A.
output_size: second dimension of matrix A.
bias: If true, add bias
gather_output: If true, call all-gether on output and make Y avaiable
to all GPUs, otherwise, every GPU will have its output
which is Y_i = XA_i
init_method: method to initialize weights. Note that bias is always set
to zero.
stride: For the strided linear layers.
keep_master_weight_for_test: This was added for testing and should be
set to False. It returns the master weights
used for initialization.
skip_bias_add: This was added to enable performance optimations where bias
can be fused with other elementwise operations. we skip
adding bias but instead return it.
"""
def __init__(
self,
input_size,
output_size,
bias=True,
gather_output=True,
init_method=init.xavier_normal_,
stride=1,
keep_master_weight_for_test=False,
skip_bias_add=False,
*,
no_async_tensor_model_parallel_allreduce=False,
params_dtype=torch.float32,
use_cpu_initialization=False,
):
super(ColumnParallelLinear, self).__init__()
# Keep input parameters
self.input_size = input_size
self.output_size = output_size
self.gather_output = gather_output
# Divide the weight matrix along the last dimension.
world_size = get_tensor_model_parallel_world_size()
self.output_size_per_partition = divide(output_size, world_size)
self.skip_bias_add = skip_bias_add
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose.
# Initialize weight.
if use_cpu_initialization:
self.weight = Parameter(torch.empty(self.output_size_per_partition, self.input_size, dtype=params_dtype))
self.master_weight = _initialize_affine_weight_cpu(
self.weight,
self.output_size,
self.input_size,
self.output_size_per_partition,
0,
init_method,
stride=stride,
return_master_weight=keep_master_weight_for_test,
params_dtype=params_dtype,
)
else:
self.weight = Parameter(
torch.empty(
self.output_size_per_partition,
self.input_size,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
_initialize_affine_weight_gpu(self.weight, init_method, partition_dim=0, stride=stride)
if bias:
if use_cpu_initialization:
self.bias = Parameter(torch.empty(self.output_size_per_partition, dtype=params_dtype))
else:
self.bias = Parameter(
torch.empty(self.output_size_per_partition, device=torch.cuda.current_device(), dtype=params_dtype)
)
set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
else:
self.register_parameter("bias", None)
self.async_tensor_model_parallel_allreduce = (
not no_async_tensor_model_parallel_allreduce and
world_size > 1)
def forward(self, input_):
bias = self.bias if not self.skip_bias_add else None
if self.async_tensor_model_parallel_allreduce:
input_shape = input_.shape
input_ = input_.view(input_shape[0] * input_shape[1],input_shape[2])
# Matrix multiply with asynchronous all-reduce execution
output_parallel = column_parallel_linear(input_, self.weight, bias)
output_parallel = output_parallel.view(
input_shape[0], input_shape[1], output_parallel.shape[1])
else:
# Set up backprop all-reduce.
input_parallel = copy_to_tensor_model_parallel_region(input_)
# Matrix multiply.
output_parallel = F.linear(input_parallel, self.weight, bias)
if self.gather_output:
# All-gather across the partitions.
output = gather_from_tensor_model_parallel_region(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
class RowParallelLinear(torch.nn.Module):
"""Linear layer with row parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its first dimension and X along its second dimension as:
- -
| A_1 |
| . |
A = | . | X = [X_1, ..., X_p]
| . |
| A_p |
- -
Arguments:
input_size: first dimension of matrix A.
output_size: second dimension of matrix A.
bias: If true, add bias. Note that bias is not parallelized.
input_is_parallel: If true, we assume that the input is already
split across the GPUs and we do not split
again.
init_method: method to initialize weights. Note that bias is always set
to zero.
stride: For the strided linear layers.
keep_master_weight_for_test: This was added for testing and should be
set to False. It returns the master weights
used for initialization.
skip_bias_add: This was added to enable performance optimization where bias
can be fused with other elementwise operations. We skip
adding bias but instead return it.
"""
def __init__(
self,
input_size,
output_size,
bias=True,
input_is_parallel=False,
init_method=init.xavier_normal_,
stride=1,
keep_master_weight_for_test=False,
skip_bias_add=False,
*,
params_dtype=torch.float32,
use_cpu_initialization=False,
):
super(RowParallelLinear, self).__init__()
# Keep input parameters
self.input_size = input_size
self.output_size = output_size
self.input_is_parallel = input_is_parallel
# Divide the weight matrix along the last dimension.
world_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, world_size)
self.skip_bias_add = skip_bias_add
# as an argument to this function?
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose.
# Initialize weight.
if use_cpu_initialization:
self.weight = Parameter(torch.empty(self.output_size, self.input_size_per_partition, dtype=params_dtype))
self.master_weight = _initialize_affine_weight_cpu(
self.weight,
self.output_size,
self.input_size,
self.input_size_per_partition,
1,
init_method,
stride=stride,
return_master_weight=keep_master_weight_for_test,
params_dtype=params_dtype,
)
else:
self.weight = Parameter(
torch.empty(
self.output_size,
self.input_size_per_partition,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
_initialize_affine_weight_gpu(self.weight, init_method, partition_dim=1, stride=stride)
if bias:
if use_cpu_initialization:
self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
else:
self.bias = Parameter(
torch.empty(self.output_size, device=torch.cuda.current_device(), dtype=params_dtype)
)
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
else:
self.register_parameter("bias", None)
def forward(self, input_):
# Set up backprop all-reduce.
if self.input_is_parallel:
input_parallel = input_
else:
input_parallel = scatter_to_tensor_model_parallel_region(input_)
# Matrix multiply.
output_parallel = F.linear(input_parallel, self.weight)
# All-reduce across all the partitions.
output_ = reduce_from_tensor_model_parallel_region(output_parallel)
if not self.skip_bias_add:
output = output_ + self.bias if self.bias is not None else output_
output_bias = None
else:
output = output_
output_bias = self.bias
return output, output_bias
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from ..parallel_state import get_tensor_model_parallel_group
from ..parallel_state import get_tensor_model_parallel_world_size
from ..parallel_state import get_tensor_model_parallel_rank
from .utils import split_tensor_along_last_dim
def _reduce(input_):
"""All-reduce the input tensor across model parallel group."""
# Bypass the function if we are using only 1 GPU.
if get_tensor_model_parallel_world_size() == 1:
return input_
# All-reduce.
torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group())
return input_
def _split(input_):
"""Split the tensor along its last dimension and keep the
corresponding slice."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
# Split along last dimension.
input_list = split_tensor_along_last_dim(input_, world_size)
# Note: torch.split does not create contiguous tensors by default.
rank = get_tensor_model_parallel_rank()
output = input_list[rank].contiguous()
return output
def _gather(input_):
"""Gather tensors and concatinate along the last dimension."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
# Size and dimension.
last_dim = input_.dim() - 1
rank = get_tensor_model_parallel_rank()
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank] = input_
torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group())
# Note: torch.cat already creates a contiguous tensor.
output = torch.cat(tensor_list, dim=last_dim).contiguous()
return output
class _CopyToModelParallelRegion(torch.autograd.Function):
"""Pass the input to the model parallel region."""
@staticmethod
def symbolic(graph, input_):
return input_
@staticmethod
def forward(ctx, input_):
return input_
@staticmethod
def backward(ctx, grad_output):
return _reduce(grad_output)
class _ReduceFromModelParallelRegion(torch.autograd.Function):
"""All-reduce the input from the model parallel region."""
@staticmethod
def symbolic(graph, input_):
return _reduce(input_)
@staticmethod
def forward(ctx, input_):
return _reduce(input_)
@staticmethod
def backward(ctx, grad_output):
return grad_output
class _ScatterToModelParallelRegion(torch.autograd.Function):
"""Split the input and keep only the corresponding chuck to the rank."""
@staticmethod
def symbolic(graph, input_):
return _split(input_)
@staticmethod
def forward(ctx, input_):
return _split(input_)
@staticmethod
def backward(ctx, grad_output):
return _gather(grad_output)
class _GatherFromModelParallelRegion(torch.autograd.Function):
"""Gather the input from model parallel region and concatinate."""
@staticmethod
def symbolic(graph, input_):
return _gather(input_)
@staticmethod
def forward(ctx, input_):
return _gather(input_)
@staticmethod
def backward(ctx, grad_output):
return _split(grad_output)
# -----------------
# Helper functions.
# -----------------
def copy_to_tensor_model_parallel_region(input_):
return _CopyToModelParallelRegion.apply(input_)
def reduce_from_tensor_model_parallel_region(input_):
return _ReduceFromModelParallelRegion.apply(input_)
def scatter_to_tensor_model_parallel_region(input_):
return _ScatterToModelParallelRegion.apply(input_)
def gather_from_tensor_model_parallel_region(input_):
return _GatherFromModelParallelRegion.apply(input_)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
# A dictionary of all the memory buffers allocated.
_MEM_BUFFS = dict()
def allocate_mem_buff(name, numel, dtype, track_usage):
"""Allocate a memory buffer."""
assert name not in _MEM_BUFFS, "memory buffer {} already allocated.".format(name)
_MEM_BUFFS[name] = MemoryBuffer(name, numel, dtype, track_usage)
return _MEM_BUFFS[name]
def get_mem_buff(name):
"""Get the memory buffer."""
return _MEM_BUFFS[name]
class MemoryBuffer:
"""Contiguous memory buffer.
Allocate a contiguous memory of type `dtype` and size `numel`. It is
used to reduce memory fragmentation.
Usage: After the allocation, the `_start` index is set tot the first
index of the memory. A memory chunk starting from `_start` index
can be `allocated` for an input tensor, with the elements of the
tensor being coppied. The buffer can be reused by resetting the
`_start` index.
"""
def __init__(self, name, numel, dtype, track_usage):
if torch.distributed.get_rank() == 0:
element_size = torch.tensor([], dtype=dtype).element_size()
print(
"> building the {} memory buffer with {} num elements "
"and {} dtype ({:.1f} MB)...".format(name, numel, dtype, numel * element_size / 1024 / 1024),
flush=True,
)
self.name = name
self.numel = numel
self.dtype = dtype
self.data = torch.empty(self.numel, dtype=self.dtype, device=torch.cuda.current_device(), requires_grad=False)
# Index tracking the start of the free memory.
self._start = 0
# Values used for tracking usage.
self.track_usage = track_usage
if self.track_usage:
self.in_use_value = 0.0
self.total_value = 0.0
def reset(self):
"""Reset the buffer start index to the beginning of the buffer."""
self._start = 0
def is_in_use(self):
"""Whether the current buffer hold on to any memory."""
return self._start > 0
def numel_in_use(self):
"""Return number of elements in use."""
return self._start
def add(self, tensor):
"""Allocate a chunk of memory from the buffer to tensor and copy
the values."""
assert tensor.dtype == self.dtype, "Input tensor type {} different from buffer type {}".format(
tensor.dtype, self.dtype
)
# Number of elements of the input tensor.
tensor_numel = torch.numel(tensor)
new_start = self._start + tensor_numel
assert new_start <= self.numel, "Not enough memory left in the buffer ({} > {})".format(
tensor_numel, self.numel - self._start
)
# New tensor is a view into the memory.
new_tensor = self.data[self._start : new_start]
self._start = new_start
new_tensor = new_tensor.view(tensor.shape)
new_tensor.copy_(tensor)
# Return a pointer to the new tensor.
return new_tensor
def get_data(self):
"""Return the data currently in use."""
if self.track_usage:
self.in_use_value += float(self._start)
self.total_value += float(self.numel)
return self.data[: self._start]
def print_average_usage(self):
"""Print memory usage average over time. We would like this value
to be as high as possible."""
assert self.track_usage, "You need to enable track usage."
if torch.distributed.get_rank() == 0:
print(
" > usage of {} memory buffer: {:.2f} %".format(
self.name, self.in_use_value * 100.0 / self.total_value
),
flush=True,
)
class RingMemBuffer:
"""A ring of memory buffers."""
def __init__(self, name, num_buffers, numel, dtype, track_usage):
self.num_buffers = num_buffers
self.buffers = [
allocate_mem_buff(name + " {}".format(i), numel, dtype, track_usage) for i in range(num_buffers)
]
self._index = -1
def get_next_buffer(self):
self._index += 1
self._index = self._index % self.num_buffers
buff = self.buffers[self._index]
assert not buff.is_in_use(), "buffer is already in use."
return buff
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron number of micro-batches calculators."""
from abc import ABC
from abc import abstractmethod
def build_num_microbatches_calculator(args):
# Constant num micro-batches.
if args.rampup_batch_size is None:
num_microbatches_calculator = ConstantNumMicroBatches(
args.global_batch_size, args.micro_batch_size, args.data_parallel_size
)
if args.rank == 0:
print(
"setting number of micro-batches to constant {}".format(num_microbatches_calculator.get()), flush=True
)
else:
assert len(args.rampup_batch_size) == 3, (
"expected the following "
"format: --rampup-batch-size <start batch size> "
"<batch size incerement> <ramp-up samples>"
)
start_batch_size = int(args.rampup_batch_size[0])
batch_size_increment = int(args.rampup_batch_size[1])
ramup_samples = int(args.rampup_batch_size[2])
if args.rank == 0:
print(
"will use batch size rampup starting from global batch "
"size {} to global batch size {} with batch size increments "
"{} over {} samples.".format(
start_batch_size, args.global_batch_size, batch_size_increment, ramup_samples
),
flush=True,
)
num_microbatches_calculator = RampupBatchsizeNumMicroBatches(
start_batch_size,
batch_size_increment,
ramup_samples,
args.global_batch_size,
args.micro_batch_size,
args.data_parallel_size,
)
return num_microbatches_calculator
class NumMicroBatchesCalculator(ABC):
def __init__(self):
self.num_micro_batches = None
self.current_global_batch_size = None
def get(self):
return self.num_micro_batches
def get_current_global_batch_size(self):
return self.current_global_batch_size
@abstractmethod
def update(self, consumed_samples, consistency_check):
pass
class ConstantNumMicroBatches(NumMicroBatchesCalculator):
def __init__(self, global_batch_size, micro_batch_size, data_parallel_size):
micro_batch_times_data_parallel = micro_batch_size * data_parallel_size
assert global_batch_size % micro_batch_times_data_parallel == 0, (
"global batch size ({}) is not divisible by micro batch size ({})"
" times data parallel size ({})".format(global_batch_size, micro_batch_size, data_parallel_size)
)
self.num_micro_batches = global_batch_size // micro_batch_times_data_parallel
assert self.num_micro_batches >= 1
self.current_global_batch_size = global_batch_size
def update(self, consumed_samples, consistency_check):
pass
class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):
def __init__(
self,
start_batch_size,
batch_size_increment,
ramup_samples,
global_batch_size,
micro_batch_size,
data_parallel_size,
):
"""Batch size ramp up.
Over
steps = (global-batch-size - start-batch-size) / batch_size_increment
increment batch size from start-batch-size to global-batch-size using
rampup-samples / steps
samples.
Arguments:
start_batch_size: global batch size to start with
batch_size_increment: global batch size increments
ramup_samples: number of samples to use ramp up global
batch size from `start_batch_size` to `global_batch_size`
global_batch_size: global batch size post rampup
micro_batch_size: micro batch size
data_parallel_size: data parallel size.
"""
self.micro_batch_size = micro_batch_size
self.data_parallel_size = data_parallel_size
self.micro_batch_times_data_parallel_size = self.micro_batch_size * self.data_parallel_size
assert self.micro_batch_times_data_parallel_size > 0
assert start_batch_size > 0
self.start_batch_size = start_batch_size
assert global_batch_size > 0
self.global_batch_size = global_batch_size
diff_batch_size = self.global_batch_size - self.start_batch_size
assert diff_batch_size >= 0
assert batch_size_increment > 0
self.batch_size_increment = batch_size_increment
assert diff_batch_size % batch_size_increment == 0, (
"expected "
"global batch size interval ({}) to be divisible by global batch "
"size increment ({})".format(diff_batch_size, batch_size_increment)
)
num_increments = diff_batch_size // self.batch_size_increment
self.ramup_samples = ramup_samples
assert self.ramup_samples >= 0
self.rampup_samples_per_increment = self.ramup_samples / num_increments
# Initialize number of microbatches.
self.update(0, False)
def update(self, consumed_samples, consistency_check):
if consumed_samples > self.ramup_samples:
self.current_global_batch_size = self.global_batch_size
else:
steps = int(consumed_samples / self.rampup_samples_per_increment)
self.current_global_batch_size = self.start_batch_size + steps * self.batch_size_increment
assert self.current_global_batch_size <= self.global_batch_size
if consistency_check:
assert self.current_global_batch_size % self.micro_batch_times_data_parallel_size == 0, (
"current global "
"batch size ({}) is not divisible by micro-batch-size ({}) times"
"data parallel size ({})".format(
self.current_global_batch_size, self.micro_batch_size, self.data_parallel_size
)
)
self.num_micro_batches = self.current_global_batch_size // self.micro_batch_times_data_parallel_size
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
import contextlib
import torch
from torch import _C
from torch.cuda import _lazy_call, device as device_ctx_manager
from torch.utils.checkpoint import detach_variable
from ..parallel_state import get_data_parallel_rank
from ..parallel_state import get_tensor_model_parallel_group
from ..parallel_state import get_tensor_model_parallel_rank
from ..parallel_state import get_tensor_model_parallel_world_size
from .memory import allocate_mem_buff
# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = "model-parallel-rng"
# Whether apply model parallelsim to checkpointed hidden states.
_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = None
# TODO (mkozuki): Consider the possibility of removing `tensor_model_parallel_size`,
# `get_tensor_model_parallel_world_size()` might be alternative.
def init_checkpointed_activations_memory_buffer(
micro_batch_size,
max_position_embeddings,
hidden_size,
num_layers,
tensor_model_parallel_size,
checkpoint_num_layers,
fp16,
):
"""Initializ the memory buffer for the checkpointed activations."""
per_layer = micro_batch_size * max_position_embeddings * hidden_size // tensor_model_parallel_size
assert num_layers % checkpoint_num_layers == 0, "number of layers is not divisible by checkpoint-num-layers"
num_checkpointer_layers = num_layers // checkpoint_num_layers
numel = per_layer * num_checkpointer_layers
dtype = torch.half
if not fp16:
dtype = torch.float
global _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER
assert (
_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is None
), "checkpointed activations memory buffer is already allocated."
_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = allocate_mem_buff(
"checkpointed activations", numel, dtype, track_usage=False
)
def reset_checkpointed_activations_memory_buffer():
"""Reset the memory used for checkpointing."""
if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.reset()
def _set_cuda_rng_state(new_state, device=-1):
"""Sets the random number generator state of the current GPU.
Argumentss:
new_state (torch.ByteTensor): The desired state
This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
with a single change: the input state is not cloned. Cloning caused
major performance issues for +4 GPU cases.
"""
if hasattr(_C, "_cuda_setRNGState") and callable(_C._cuda_setRNGState):
# older PyTorch
def cb():
with device_ctx_manager(device):
_C._cuda_setRNGState(new_state)
else:
# newer PyTorch
if device == -1:
device = torch.device("cuda")
elif isinstance(device, str):
device = torch.device(device)
elif isinstance(device, int):
device = torch.device("cuda", device)
def cb():
idx = device.index
if idx is None:
idx = torch.cuda.current_device()
default_generator = torch.cuda.default_generators[idx]
default_generator.set_state(new_state)
_lazy_call(cb)
def split_tensor_into_1d_equal_chunks(tensor):
"""Break a tensor into equal 1D chunks."""
data = tensor.view(-1)
partition_size = torch.numel(data) // get_tensor_model_parallel_world_size()
start_index = partition_size * get_tensor_model_parallel_rank()
end_index = start_index + partition_size
return data[start_index:end_index]
def gather_split_1d_tensor(tensor):
"""Opposite of above function, gather values from model parallel ranks."""
world_size = get_tensor_model_parallel_world_size()
numel = torch.numel(tensor)
numel_gathered = world_size * numel
gathered = torch.empty(numel_gathered, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False)
chunks = [gathered[i * numel : (i + 1) * numel] for i in range(world_size)]
torch.distributed.all_gather(chunks, tensor, group=get_tensor_model_parallel_group())
return gathered
class CudaRNGStatesTracker:
"""Tracker for the cuda RNG states.
Using the `add` method, a cuda rng state is initialized based on
the input `seed` and is assigned to `name`. Later, by forking the
rng state, we can perform operations and return to our starting
cuda state.
"""
def __init__(self):
# Map from a string name to the cuda rng state.
self.states_ = {}
# Seeds are just for book keeping and ensure no seed is set twice.
self.seeds_ = set()
def reset(self):
"""Set to the initial state (no tracker)."""
self.states_ = {}
self.seeds_ = set()
def get_states(self):
"""Get rng states. Copy the dictionary so we have direct
pointers to the states, not just a pointer to the dictionary."""
states = {}
for name in self.states_:
states[name] = self.states_[name]
return states
def set_states(self, states):
"""Set the rng states. For efficiency purposes, we do not check
the size of seed for compatibility."""
self.states_ = states
def add(self, name, seed):
"""Track the rng state."""
# Check seed is not already used.
if seed in self.seeds_:
raise Exception("seed {} already exists".format(seed))
self.seeds_.add(seed)
# Check that state is not already defined.
if name in self.states_:
raise Exception("cuda rng state {} already exists".format(name))
# Get the current rng state.
orig_rng_state = torch.cuda.get_rng_state()
# Set the new state and store it.
torch.cuda.manual_seed(seed)
self.states_[name] = torch.cuda.get_rng_state()
# Reset rng state to what it was.
_set_cuda_rng_state(orig_rng_state)
@contextlib.contextmanager
def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
"""Fork the cuda rng state, perform operations, and exit with
the original state."""
# Check if we have added the state
if name not in self.states_:
raise Exception("cuda rng state {} is not added".format(name))
# Store current rng state.
orig_cuda_rng_state = torch.cuda.get_rng_state()
# Set rng state to the desired one
_set_cuda_rng_state(self.states_[name])
# Do the stuff we wanted to do.
try:
yield
finally:
# Update the current rng state for later use.
self.states_[name] = torch.cuda.get_rng_state()
# And set the state to the original state we started with.
_set_cuda_rng_state(orig_cuda_rng_state)
# RNG tracker object.
_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
def get_cuda_rng_tracker():
"""Get cuda rng tracker."""
return _CUDA_RNG_STATE_TRACKER
def model_parallel_cuda_manual_seed(seed):
"""Initialize model parallel cuda seed.
This function should be called after the model parallel is
initialized. Also, no torch.cuda.manual_seed should be called
after this function. Basically, this is replacement for that
function.
Two set of RNG states are tracked:
default state: This is for data parallelism and is the same among a
set of model parallel GPUs but different across
different model paralle groups. This is used for
example for dropout in the non-tensor-model-parallel regions.
tensor-model-parallel state: This state is different among a set of model
parallel GPUs, but the same across data parallel
groups. This is used for example for dropout in
model parallel regions.
"""
# 2718 is just for fun and any POSITIVE value will work.
offset = seed + 2718
tensor_model_parallel_seed = offset + get_tensor_model_parallel_rank()
# Data parallel gets the original seed.
data_parallel_seed = seed
_CUDA_RNG_STATE_TRACKER.reset()
# Set the default state.
torch.cuda.manual_seed(data_parallel_seed)
# and model parallel state.
_CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, tensor_model_parallel_seed)
class CheckpointFunction(torch.autograd.Function):
"""This function is adapted from torch.utils.checkpoint with
two main changes:
1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
2) the states in the model parallel tracker are also properly
tracked/set/reset.
"""
@staticmethod
def forward(ctx, run_function, *args):
ctx.run_function = run_function
# Copy the rng states.
ctx.fwd_cpu_rng_state = torch.get_rng_state()
ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
with torch.no_grad():
outputs = run_function(*args)
# Divide hidden states across model parallel group and only keep
# the chunk corresponding to the current rank.
if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
ctx.input_0_shape = args[0].data.shape
args[0].data = split_tensor_into_1d_equal_chunks(args[0].data)
args[0].data = _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.add(args[0].data)
# Store everything.
ctx.save_for_backward(*args)
return outputs
@staticmethod
def backward(ctx, *args):
if not torch.autograd._is_checkpoint_valid():
raise RuntimeError("Checkpointing is not compatible with .grad(), " "please use .backward() if possible")
inputs = ctx.saved_tensors
if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
inputs[0].data = gather_split_1d_tensor(inputs[0].data)
inputs[0].data = inputs[0].data.view(ctx.input_0_shape)
# Store the current states.
bwd_cpu_rng_state = torch.get_rng_state()
bwd_cuda_rng_state = torch.cuda.get_rng_state()
bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
# Set the states to what it used to be before the forward pass.
torch.set_rng_state(ctx.fwd_cpu_rng_state)
_set_cuda_rng_state(ctx.fwd_cuda_rng_state)
get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)
# Compute the forward pass.
detached_inputs = detach_variable(inputs)
with torch.enable_grad():
outputs = ctx.run_function(*detached_inputs)
# Set the states back to what it was at the start of this function.
torch.set_rng_state(bwd_cpu_rng_state)
_set_cuda_rng_state(bwd_cuda_rng_state)
get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)
if isinstance(outputs, torch.Tensor):
outputs = (outputs,)
torch.autograd.backward(outputs, args)
grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs)
return (None,) + grads
def checkpoint(function, *args):
"""Checkpoint a model or part of the model.
This has been directly copied from torch.utils.checkpoint."""
return CheckpointFunction.apply(function, *args)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron arguments."""
import argparse
import os
import torch
def parse_args(extra_args_provider=None, defaults={},
ignore_unknown_args=False):
"""Parse all arguments."""
parser = argparse.ArgumentParser(description='Megatron-LM Arguments',
allow_abbrev=False)
# Standard arguments.
parser = _add_network_size_args(parser)
parser = _add_regularization_args(parser)
parser = _add_training_args(parser)
parser = _add_initialization_args(parser)
parser = _add_learning_rate_args(parser)
parser = _add_checkpointing_args(parser)
parser = _add_mixed_precision_args(parser)
parser = _add_distributed_args(parser)
parser = _add_validation_args(parser)
parser = _add_data_args(parser)
parser = _add_autoresume_args(parser)
parser = _add_biencoder_args(parser)
parser = _add_vit_args(parser)
parser = _add_logging_args(parser)
# Custom arguments.
if extra_args_provider is not None:
parser = extra_args_provider(parser)
# Parse.
if ignore_unknown_args:
args, _ = parser.parse_known_args()
else:
args = parser.parse_args()
# Distributed args.
args.rank = int(os.getenv('RANK', '0'))
args.world_size = int(os.getenv("WORLD_SIZE", '1'))
# Tensor model parallel size.
args.tensor_model_parallel_size = min(
args.tensor_model_parallel_size, args.world_size)
assert args.world_size % args.tensor_model_parallel_size == 0, 'world size'\
' ({}) is not divisible by tensor model parallel size ({})'.format(
args.world_size, args.tensor_model_parallel_size)
# Pipeline model parallel size.
args.pipeline_model_parallel_size = min(
args.pipeline_model_parallel_size,
(args.world_size // args.tensor_model_parallel_size))
# Checks.
model_parallel_size = args.pipeline_model_parallel_size * \
args.tensor_model_parallel_size
assert args.world_size % model_parallel_size == 0, 'world size is not'\
' divisible by tensor parallel size ({}) times pipeline parallel ' \
'size ({})'.format(args.world_size, args.tensor_model_parallel_size,
args.pipeline_model_parallel_size)
args.data_parallel_size = args.world_size // model_parallel_size
if args.rank == 0:
print('using world size: {}, data-parallel-size: {}, '
'tensor-model-parallel size: {}, '
'pipeline-model-parallel size: {} '.format(
args.world_size, args.data_parallel_size,
args.tensor_model_parallel_size,
args.pipeline_model_parallel_size), flush=True)
# Deprecated arguments
assert args.batch_size is None, '--batch-size argument is no longer ' \
'valid, use --micro-batch-size instead'
del args.batch_size
assert args.warmup is None, '--warmup argument is no longer valid, use ' \
'--lr-warmup-fraction instead'
del args.warmup
assert args.model_parallel_size is None, '--model-parallel-size is no ' \
'longer valid, use --tensor-model-parallel-size instead'
del args.model_parallel_size
# Set input defaults.
for key in defaults:
# For default to be valid, it should not be provided in the
# arguments that are passed to the program. We check this by
# ensuring the arg is set to None.
if getattr(args, key) is not None:
if args.rank == 0:
print('WARNING: overriding default arguments for {key}:{v} \
with {key}:{v2}'.format(key=key, v=defaults[key],
v2=getattr(args, key)),
flush=True)
else:
setattr(args, key, defaults[key])
# Batch size.
assert args.micro_batch_size is not None
assert args.micro_batch_size > 0
if args.global_batch_size is None:
args.global_batch_size = args.micro_batch_size * args.data_parallel_size
if args.rank == 0:
print('setting global batch size to {}'.format(
args.global_batch_size), flush=True)
assert args.global_batch_size > 0
if args.num_layers_per_virtual_pipeline_stage is not None:
assert args.pipeline_model_parallel_size > 2, \
'pipeline-model-parallel size should be greater than 2 with ' \
'interleaved schedule'
assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, \
'number of layers is not divisible by number of layers per virtual ' \
'pipeline stage'
args.virtual_pipeline_model_parallel_size = \
(args.num_layers // args.pipeline_model_parallel_size) // \
args.num_layers_per_virtual_pipeline_stage
else:
args.virtual_pipeline_model_parallel_size = None
# Parameters dtype.
args.params_dtype = torch.float
if args.fp16:
assert not args.bf16
args.params_dtype = torch.half
if args.bf16:
assert not args.fp16
args.params_dtype = torch.bfloat16
# bfloat16 requires gradient accumulation and all-reduce to
# be done in fp32.
if not args.accumulate_allreduce_grads_in_fp32:
args.accumulate_allreduce_grads_in_fp32 = True
if args.rank == 0:
print('accumulate and all-reduce gradients in fp32 for '
'bfloat16 data type.', flush=True)
if args.rank == 0:
print('using {} for parameters ...'.format(args.params_dtype),
flush=True)
# If we do accumulation and all-reduces in fp32, we need to have
# local DDP and we should set the use-contiguous-buffers-in-ddp.
if args.accumulate_allreduce_grads_in_fp32:
assert args.DDP_impl == 'local'
args.use_contiguous_buffers_in_ddp = True
# If we use a contiguous buffer to hold main grads, we need to have
# local DDP.
if args.use_contiguous_buffers_in_ddp:
assert args.DDP_impl == 'local'
if args.dataloader_type is None:
args.dataloader_type = 'single'
# Consumed tokens.
args.consumed_train_samples = 0
args.consumed_valid_samples = 0
# Iteration-based training.
if args.train_iters:
# If we use iteration-based training, make sure the
# sample-based options are off.
assert args.train_samples is None, \
'expected iteration-based training'
assert args.lr_decay_samples is None, \
'expected iteration-based learning rate decay'
assert args.lr_warmup_samples == 0, \
'expected iteration-based learning rate warmup'
assert args.rampup_batch_size is None, \
'expected no batch-size rampup for iteration-based training'
if args.lr_warmup_fraction is not None:
assert args.lr_warmup_iters == 0, \
'can only specify one of lr-warmup-fraction and lr-warmup-iters'
# Sample-based training.
if args.train_samples:
# If we use sample-based training, make sure the
# iteration-based options are off.
assert args.train_iters is None, \
'expected sample-based training'
assert args.lr_decay_iters is None, \
'expected sample-based learning rate decay'
assert args.lr_warmup_iters == 0, \
'expected sample-based learnig rate warmup'
if args.lr_warmup_fraction is not None:
assert args.lr_warmup_samples == 0, \
'can only specify one of lr-warmup-fraction ' \
'and lr-warmup-samples'
# Check required arguments.
required_args = ['num_layers', 'hidden_size', 'num_attention_heads',
'max_position_embeddings']
for req_arg in required_args:
_check_arg_is_not_none(args, req_arg)
# Checks.
if args.ffn_hidden_size is None:
args.ffn_hidden_size = 4 * args.hidden_size
if args.kv_channels is None:
assert args.hidden_size % args.num_attention_heads == 0
args.kv_channels = args.hidden_size // args.num_attention_heads
if args.seq_length is not None:
assert args.encoder_seq_length is None
args.encoder_seq_length = args.seq_length
else:
assert args.encoder_seq_length is not None
args.seq_length = args.encoder_seq_length
if args.seq_length is not None:
assert args.max_position_embeddings >= args.seq_length
if args.decoder_seq_length is not None:
assert args.max_position_embeddings >= args.decoder_seq_length
if args.lr is not None:
assert args.min_lr <= args.lr
if args.save is not None:
assert args.save_interval is not None
# Mixed precision checks.
if args.fp16_lm_cross_entropy:
assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
if args.fp32_residual_connection:
assert args.fp16 or args.bf16, \
'residual connection in fp32 only supported when using fp16 or bf16.'
# Activation checkpointing.
if args.distribute_checkpointed_activations:
assert args.checkpoint_activations, \
'for distribute-checkpointed-activations to work you '\
'need to enable checkpoint-activations'
_print_args(args)
return args
def _print_args(args):
"""Print arguments."""
if args.rank == 0:
print('------------------------ arguments ------------------------',
flush=True)
str_list = []
for arg in vars(args):
dots = '.' * (48 - len(arg))
str_list.append(' {} {} {}'.format(arg, dots, getattr(args, arg)))
for arg in sorted(str_list, key=lambda x: x.lower()):
print(arg, flush=True)
print('-------------------- end of arguments ---------------------',
flush=True)
def _check_arg_is_not_none(args, arg):
assert getattr(args, arg) is not None, '{} argument is None'.format(arg)
def _add_network_size_args(parser):
group = parser.add_argument_group(title='network size')
group.add_argument('--num-layers', type=int, default=None,
help='Number of transformer layers.')
group.add_argument('--hidden-size', type=int, default=None,
help='Tansformer hidden size.')
group.add_argument('--ffn-hidden-size', type=int, default=None,
help='Transformer Feed-Forward Network hidden size. '
'This is set to 4*hidden-size if not provided')
group.add_argument('--num-attention-heads', type=int, default=None,
help='Number of transformer attention heads.')
group.add_argument('--kv-channels', type=int, default=None,
help='Projection weights dimension in multi-head '
'attention. This is set to '
' args.hidden_size // args.num_attention_heads '
'if not provided.')
group.add_argument('--max-position-embeddings', type=int, default=None,
help='Maximum number of position embeddings to use. '
'This is the size of position embedding.')
group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
help='Pad the vocab size to be divisible by this value.'
'This is added for computational efficieny reasons.')
group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
help='Layer norm epsilon.')
group.add_argument('--apply-residual-connection-post-layernorm',
action='store_true',
help='If set, use original BERT residula connection '
'ordering.')
group.add_argument('--openai-gelu', action='store_true',
help='Use OpenAIs GeLU implementation. This option'
'should not be used unless for backward compatibility'
'reasons.')
group.add_argument('--onnx-safe', type=bool, required=False,
help='Use workarounds for known problems with '
'Torch ONNX exporter')
group.add_argument('--bert-no-binary-head', action='store_false',
help='Disable BERT binary head.',
dest='bert_binary_head')
return parser
def _add_logging_args(parser):
group = parser.add_argument_group(title='logging')
group.add_argument('--log-params-norm', action='store_true',
help='If set, calculate and log parameters norm.')
group.add_argument('--log-num-zeros-in-grad', action='store_true',
help='If set, calculate and log the number of zeros in gradient.')
group.add_argument('--tensorboard-log-interval', type=int, default=1,
help='Report to tensorboard interval.')
group.add_argument('--tensorboard-queue-size', type=int, default=1000,
help='Size of the tensorboard queue for pending events '
'and summaries before one of the ‘add’ calls forces a '
'flush to disk.')
group.add_argument('--log-timers-to-tensorboard', action='store_true',
help='If set, write timers to tensorboard.')
group.add_argument('--log-batch-size-to-tensorboard', action='store_true',
help='If set, write batch-size to tensorboard.')
group.add_argument('--no-log-learnig-rate-to-tensorboard',
action='store_false',
help='Disable learning rate logging to tensorboard.',
dest='log_learning_rate_to_tensorboard')
group.add_argument('--no-log-loss-scale-to-tensorboard',
action='store_false',
help='Disable loss-scale logging to tensorboard.',
dest='log_loss_scale_to_tensorboard')
group.add_argument('--log-validation-ppl-to-tensorboard',
action='store_true',
help='If set, write validation perplexity to '
'tensorboard.')
group.add_argument('--log-memory-to-tensorboard',
action='store_true',
help='Enable memory logging to tensorboard.')
return parser
def _add_regularization_args(parser):
group = parser.add_argument_group(title='regularization')
group.add_argument('--attention-dropout', type=float, default=0.1,
help='Post attention dropout probability.')
group.add_argument('--hidden-dropout', type=float, default=0.1,
help='Dropout probability for hidden state transformer.')
group.add_argument('--weight-decay', type=float, default=0.01,
help='Weight decay coefficient for L2 regularization.')
group.add_argument('--clip-grad', type=float, default=1.0,
help='Gradient clipping based on global L2 norm.')
group.add_argument('--adam-beta1', type=float, default=0.9,
help='First coefficient for computing running averages '
'of gradient and its square')
group.add_argument('--adam-beta2', type=float, default=0.999,
help='Second coefficient for computing running averages '
'of gradient and its square')
group.add_argument('--adam-eps', type=float, default=1e-08,
help='Term added to the denominator to improve'
'numerical stability')
group.add_argument('--sgd-momentum', type=float, default=0.9,
help='Momentum factor for sgd')
return parser
def _add_training_args(parser):
group = parser.add_argument_group(title='training')
group.add_argument('--micro-batch-size', type=int, default=None,
help='Batch size per model instance (local batch size). '
'Global batch size is local batch size times data '
'parallel size times number of micro batches.')
group.add_argument('--batch-size', type=int, default=None,
help='Old batch size parameter, do not use. '
'Use --micro-batch-size instead')
group.add_argument('--global-batch-size', type=int, default=None,
help='Training batch size. If set, it should be a '
'multiple of micro-batch-size times data-parallel-size. '
'If this value is None, then '
'use micro-batch-size * data-parallel-size as the '
'global batch size. This choice will result in 1 for '
'number of micro-batches.')
group.add_argument('--rampup-batch-size', nargs='*', default=None,
help='Batch size ramp up with the following values:'
' --rampup-batch-size <start batch size> '
' <batch size incerement> '
' <ramp-up samples> '
'For example:'
' --rampup-batch-size 16 8 300000 \ '
' --global-batch-size 1024'
'will start with global batch size 16 and over '
' (1024 - 16) / 8 = 126 intervals will increase'
'the batch size linearly to 1024. In each interval'
'we will use approximately 300000 / 126 = 2380 samples.')
group.add_argument('--checkpoint-activations', action='store_true',
help='Checkpoint activation to allow for training '
'with larger models, sequences, and batch sizes.')
group.add_argument('--distribute-checkpointed-activations',
action='store_true',
help='If set, distribute checkpointed activations '
'across model parallel group.')
group.add_argument('--checkpoint-num-layers', type=int, default=1,
help='chunk size (number of layers) for checkpointing.')
group.add_argument('--train-iters', type=int, default=None,
help='Total number of iterations to train over all '
'training runs. Note that either train-iters or '
'train-samples should be provided.')
group.add_argument('--train-samples', type=int, default=None,
help='Total number of samples to train over all '
'training runs. Note that either train-iters or '
'train-samples should be provided.')
group.add_argument('--log-interval', type=int, default=100,
help='Report loss and timing interval.')
group.add_argument('--exit-interval', type=int, default=None,
help='Exit the program after the iteration is divisible '
'by this value.')
group.add_argument('--exit-duration-in-mins', type=int, default=None,
help='Exit the program after this many minutes.')
group.add_argument('--tensorboard-dir', type=str, default=None,
help='Write TensorBoard logs to this directory.')
group.add_argument('--no-masked-softmax-fusion',
action='store_false',
help='Disable fusion of query_key_value scaling, '
'masking, and softmax.',
dest='masked_softmax_fusion')
group.add_argument('--no-bias-gelu-fusion', action='store_false',
help='Disable bias and gelu fusion.',
dest='bias_gelu_fusion')
group.add_argument('--no-bias-dropout-fusion', action='store_false',
help='Disable bias and dropout fusion.',
dest='bias_dropout_fusion')
group.add_argument('--optimizer', type=str, default='adam',
choices=['adam', 'sgd'],
help='Optimizer function')
group.add_argument('--dataloader-type', type=str, default=None,
choices=['single', 'cyclic'],
help='Single pass vs multiple pass data loader')
return parser
def _add_initialization_args(parser):
group = parser.add_argument_group(title='initialization')
group.add_argument('--seed', type=int, default=1234,
help='Random seed used for python, numpy, '
'pytorch, and cuda.')
group.add_argument('--init-method-std', type=float, default=0.02,
help='Standard deviation of the zero mean normal '
'distribution used for weight initialization.')
group.add_argument('--init-method-xavier-uniform', action='store_true',
help='Enable Xavier uniform parameter initialization')
return parser
def _add_learning_rate_args(parser):
group = parser.add_argument_group(title='learning rate')
group.add_argument('--lr', type=float, default=None,
help='Initial learning rate. Depending on decay style '
'and initial warmup, the learing rate at each '
'iteration would be different.')
group.add_argument('--lr-decay-style', type=str, default='linear',
choices=['constant', 'linear', 'cosine'],
help='Learning rate decay function.')
group.add_argument('--lr-decay-iters', type=int, default=None,
help='number of iterations to decay learning rate over,'
' If None defaults to `--train-iters`')
group.add_argument('--lr-decay-samples', type=int, default=None,
help='number of samples to decay learning rate over,'
' If None defaults to `--train-samples`')
group.add_argument('--lr-warmup-fraction', type=float, default=None,
help='fraction of lr-warmup-(iters/samples) to use '
'for warmup (as a float)')
group.add_argument('--lr-warmup-iters', type=int, default=0,
help='number of iterations to linearly warmup '
'learning rate over.')
group.add_argument('--lr-warmup-samples', type=int, default=0,
help='number of samples to linearly warmup '
'learning rate over.')
group.add_argument('--warmup', type=int, default=None,
help='Old lr warmup argument, do not use. Use one of the'
'--lr-warmup-* arguments above')
group.add_argument('--min-lr', type=float, default=0.0,
help='Minumum value for learning rate. The scheduler'
'clip values below this threshold.')
group.add_argument('--override-lr-scheduler', action='store_true',
help='Reset the values of the scheduler (learning rate,'
'warmup iterations, minimum learning rate, maximum '
'number of iterations, and decay style from input '
'arguments and ignore values from checkpoints. Note'
'that all the above values will be reset.')
group.add_argument('--use-checkpoint-lr-scheduler', action='store_true',
help='Use checkpoint to set the values of the scheduler '
'(learning rate, warmup iterations, minimum learning '
'rate, maximum number of iterations, and decay style '
'from checkpoint and ignore input arguments.')
return parser
def _add_checkpointing_args(parser):
group = parser.add_argument_group(title='checkpointing')
group.add_argument('--save', type=str, default=None,
help='Output directory to save checkpoints to.')
group.add_argument('--save-interval', type=int, default=None,
help='Number of iterations between checkpoint saves.')
group.add_argument('--no-save-optim', action='store_true', default=None,
help='Do not save current optimizer.')
group.add_argument('--no-save-rng', action='store_true', default=None,
help='Do not save current rng state.')
group.add_argument('--load', type=str, default=None,
help='Directory containing a model checkpoint.')
group.add_argument('--no-load-optim', action='store_true', default=None,
help='Do not load optimizer when loading checkpoint.')
group.add_argument('--no-load-rng', action='store_true', default=None,
help='Do not load rng state when loading checkpoint.')
group.add_argument('--finetune', action='store_true',
help='Load model for finetuning. Do not load optimizer '
'or rng state from checkpoint and set iteration to 0. '
'Assumed when loading a release checkpoint.')
return parser
def _add_mixed_precision_args(parser):
group = parser.add_argument_group(title='mixed precision')
group.add_argument('--fp16', action='store_true',
help='Run model in fp16 mode.')
group.add_argument('--bf16', action='store_true',
help='Run model in bfloat16 mode.')
group.add_argument('--loss-scale', type=float, default=None,
help='Static loss scaling, positive power of 2 '
'values can improve fp16 convergence. If None, dynamic'
'loss scaling is used.')
group.add_argument('--initial-loss-scale', type=float, default=2**32,
help='Initial loss-scale for dynamic loss scaling.')
group.add_argument('--min-loss-scale', type=float, default=1.0,
help='Minimum loss scale for dynamic loss scale.')
group.add_argument('--loss-scale-window', type=float, default=1000,
help='Window over which to raise/lower dynamic scale.')
group.add_argument('--hysteresis', type=int, default=2,
help='hysteresis for dynamic loss scaling')
group.add_argument('--fp32-residual-connection', action='store_true',
help='Move residual connections to fp32.')
group.add_argument('--no-query-key-layer-scaling', action='store_false',
help='Do not scale Q * K^T by 1 / layer-number.',
dest='apply_query_key_layer_scaling')
group.add_argument('--attention-softmax-in-fp32', action='store_true',
help='Run attention masking and softmax in fp32. '
'This flag is ignored unless '
'--no-query-key-layer-scaling is specified.')
group.add_argument('--accumulate-allreduce-grads-in-fp32',
action='store_true',
help='Gradient accumulation and all-reduce in fp32.')
group.add_argument('--fp16-lm-cross-entropy', action='store_true',
help='Move the cross entropy unreduced loss calculation'
'for lm head to fp16.')
return parser
def _add_distributed_args(parser):
group = parser.add_argument_group(title='distributed')
group.add_argument('--tensor-model-parallel-size', type=int, default=1,
help='Degree of tensor model parallelism.')
group.add_argument('--pipeline-model-parallel-size', type=int, default=1,
help='Degree of pipeline model parallelism.')
group.add_argument('--model-parallel-size', type=int, default=None,
help='Old model parallel argument, do not use. Use '
'--tensor-model-parallel-size instead.')
group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None,
help='Number of layers per virtual pipeline stage')
group.add_argument('--distributed-backend', default='nccl',
choices=['nccl', 'gloo'],
help='Which backend to use for distributed training.')
group.add_argument('--DDP-impl', default='local',
choices=['local', 'torch'],
help='which DistributedDataParallel implementation '
'to use.')
group.add_argument('--use-contiguous-buffers-in-ddp', action='store_true',
help='If set, use contiguous buffer in DDP. Note that '
'this option only works woth local DDP.' )
group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
help='Use scatter/gather to optimize communication of tensors in pipeline',
dest='scatter_gather_tensors_in_pipeline')
group.add_argument('--local_rank', type=int, default=None,
help='local rank passed from distributed launcher.')
group.add_argument('--lazy-mpu-init', type=bool, required=False,
help='If set to True, initialize_megatron() '
'skips DDP initialization and returns function to '
'complete it instead.Also turns on '
'--use-cpu-initialization flag. This is for '
'external DDP manager.' )
group.add_argument('--use-cpu-initialization', action='store_true',
default=None, help='If set, affine parallel weights '
'initialization uses CPU' )
group.add_argument('--empty-unused-memory-level', default=0, type=int,
choices=[0, 1, 2],
help='Call torch.cuda.empty_cache() each iteration '
'(training and eval), to reduce fragmentation.'
'0=off, 1=moderate, 2=aggressive.')
return parser
def _add_validation_args(parser):
group = parser.add_argument_group(title='validation')
group.add_argument('--eval-iters', type=int, default=100,
help='Number of iterations to run for evaluation'
'validation/test for.')
group.add_argument('--eval-interval', type=int, default=1000,
help='Interval between running evaluation on '
'validation set.')
return parser
def _add_data_args(parser):
group = parser.add_argument_group(title='data and dataloader')
group.add_argument('--data-path', nargs='*', default=None,
help='Path to the training dataset. Accepted format:'
'1) a single data path, 2) multiple datasets in the'
'form: dataset1-weight dataset1-path dataset2-weight '
'dataset2-path ...')
group.add_argument('--split', type=str, default='969, 30, 1',
help='Comma-separated list of proportions for training,'
' validation, and test split. For example the split '
'`90,5,5` will use 90%% of data for training, 5%% for '
'validation and 5%% for test.')
group.add_argument('--vocab-file', type=str, default=None,
help='Path to the vocab file.')
group.add_argument('--merge-file', type=str, default=None,
help='Path to the BPE merge file.')
group.add_argument('--vocab-extra-ids', type=int, default=0,
help='Number of additional vocabulary tokens. '
'They are used for span masking in the T5 model')
group.add_argument('--seq-length', type=int, default=None,
help='Maximum sequence length to process.')
group.add_argument('--encoder-seq-length', type=int, default=None,
help='Maximum encoder sequence length to process.'
'This should be exclusive of --seq-length')
group.add_argument('--decoder-seq-length', type=int, default=None,
help="Maximum decoder sequence length to process.")
group.add_argument('--retriever-seq-length', type=int, default=256,
help='Maximum sequence length for the biencoder model '
' for retriever')
group.add_argument('--sample-rate', type=float, default=1.0,
help='sample rate for training data. Supposed to be 0 '
' < sample_rate < 1')
group.add_argument('--mask-prob', type=float, default=0.15,
help='Probability of replacing a token with mask.')
group.add_argument('--short-seq-prob', type=float, default=0.1,
help='Probability of producing a short sequence.')
group.add_argument('--mmap-warmup', action='store_true',
help='Warm up mmap files.')
group.add_argument('--num-workers', type=int, default=2,
help="Dataloader number of workers.")
group.add_argument('--tokenizer-type', type=str,
default=None,
choices=['BertWordPieceLowerCase',
'BertWordPieceCase',
'GPT2BPETokenizer'],
help='What type of tokenizer to use.')
group.add_argument('--data-impl', type=str, default='infer',
choices=['lazy', 'cached', 'mmap', 'infer'],
help='Implementation of indexed datasets.')
group.add_argument('--reset-position-ids', action='store_true',
help='Reset posistion ids after end-of-document token.')
group.add_argument('--reset-attention-mask', action='store_true',
help='Reset self attention maske after '
'end-of-document token.')
group.add_argument('--eod-mask-loss', action='store_true',
help='Mask loss for the end of document tokens.')
return parser
def _add_autoresume_args(parser):
group = parser.add_argument_group(title='autoresume')
group.add_argument('--adlr-autoresume', action='store_true',
help='Enable autoresume on adlr cluster.')
group.add_argument('--adlr-autoresume-interval', type=int, default=1000,
help='Intervals over which check for autoresume'
'termination signal')
return parser
def _add_biencoder_args(parser):
group = parser.add_argument_group(title='biencoder')
# network size
group.add_argument('--ict-head-size', type=int, default=None,
help='Size of block embeddings to be used in ICT and '
'REALM (paper default: 128)')
group.add_argument('--biencoder-projection-dim', type=int, default=0,
help='Size of projection head used in biencoder (paper'
' default: 128)')
group.add_argument('--biencoder-shared-query-context-model', action='store_true',
help='Whether to share the parameters of the query '
'and context models or not')
# checkpointing
group.add_argument('--ict-load', type=str, default=None,
help='Directory containing an ICTBertModel checkpoint')
group.add_argument('--bert-load', type=str, default=None,
help='Directory containing an BertModel checkpoint '
'(needed to start ICT and REALM)')
# data
group.add_argument('--titles-data-path', type=str, default=None,
help='Path to titles dataset used for ICT')
group.add_argument('--query-in-block-prob', type=float, default=0.1,
help='Probability of keeping query in block for '
'ICT dataset')
group.add_argument('--use-one-sent-docs', action='store_true',
help='Whether to use one sentence documents in ICT')
group.add_argument('--evidence-data-path', type=str, default=None,
help='Path to Wikipedia Evidence frm DPR paper')
# training
group.add_argument('--retriever-report-topk-accuracies', nargs='+', type=int,
default=[], help="Which top-k accuracies to report "
"(e.g. '1 5 20')")
group.add_argument('--retriever-score-scaling', action='store_true',
help='Whether to scale retriever scores by inverse '
'square root of hidden size')
# faiss index
group.add_argument('--block-data-path', type=str, default=None,
help='Where to save/load BlockData to/from')
group.add_argument('--embedding-path', type=str, default=None,
help='Where to save/load Open-Retrieval Embedding'
' data to/from')
# indexer
group.add_argument('--indexer-batch-size', type=int, default=128,
help='How large of batches to use when doing indexing '
'jobs')
group.add_argument('--indexer-log-interval', type=int, default=1000,
help='After how many batches should the indexer '
'report progress')
return parser
def _add_vit_args(parser):
group = parser.add_argument_group(title="vit")
group.add_argument('--num-classes', type=int, default=1000,
help='num of classes in vision classificaiton task')
group.add_argument('--img-dim', type=int, default=224,
help='Image size for vision classification task')
group.add_argument('--num-channels', type=int, default=3,
help='Number of channels in input image data')
group.add_argument('--patch-dim', type=int, default=16,
help='patch dimension used in vit')
return parser
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import numpy
import torch
from apex import transformer
from apex.transformer.tensor_parallel.tests import global_vars
TEST_SUCCESS_MESSAGE = ">> passed the test :-)"
class IdentityLayer(torch.nn.Module):
def __init__(self, size, scale=1.0):
super(IdentityLayer, self).__init__()
self.weight = torch.nn.Parameter(scale * torch.randn(size))
def forward(self):
return self.weight
def set_random_seed(seed):
"""Set random seed for reproducibility."""
random.seed(seed)
numpy.random.seed(seed)
torch.manual_seed(seed)
transformer.tensor_parallel.model_parallel_cuda_manual_seed(seed)
def initialize_distributed(backend='nccl'):
"""Initialize torch.distributed."""
# Get local rank in case it is provided.
# parser = argparse.ArgumentParser()
# parser.add_argument('--local_rank', type=int, default=None,
# help='local rank passed from distributed launcher')
# args = parser.parse_args()
args = global_vars.get_args()
local_rank = args.local_rank
# Get rank and world size.
rank = int(os.getenv('RANK', '0'))
world_size = int(os.getenv("WORLD_SIZE", '1'))
print('> initializing torch.distributed with local rank: {}, '
'rank: {}, world size: {}'.format(local_rank, rank, world_size))
# Set the device id.
device = rank % torch.cuda.device_count()
if local_rank is not None:
device = local_rank
torch.cuda.set_device(device)
# Call the init process.
init_method = 'tcp://'
master_ip = os.getenv('MASTER_ADDR', 'localhost')
master_port = os.getenv('MASTER_PORT', '6000')
init_method += master_ip + ':' + master_port
torch.distributed.init_process_group(
backend=backend,
world_size=world_size,
rank=rank,
init_method=init_method)
def print_separator(message):
torch.distributed.barrier()
filler_len = (78 - len(message)) // 2
filler = '-' * filler_len
string = '\n' + filler + ' {} '.format(message) + filler
if torch.distributed.get_rank() == 0:
print(string, flush=True)
torch.distributed.barrier()
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron global variables."""
import os
import sys
import time
import torch
from apex.transformer.tensor_parallel.microbatches import build_num_microbatches_calculator
from apex.transformer.tensor_parallel.tests.arguments import parse_args
_GLOBAL_ARGS = None
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
_GLOBAL_TOKENIZER = None
_GLOBAL_TENSORBOARD_WRITER = None
_GLOBAL_ADLR_AUTORESUME = None
_GLOBAL_TIMERS = None
def get_args():
"""Return arguments."""
_ensure_var_is_initialized(_GLOBAL_ARGS, 'args')
return _GLOBAL_ARGS
def get_num_microbatches():
return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()
def get_current_global_batch_size():
return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_current_global_batch_size()
def update_num_microbatches(consumed_samples, consistency_check=True):
_GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples,
consistency_check)
# def get_tokenizer():
# """Return tokenizer."""
# _ensure_var_is_initialized(_GLOBAL_TOKENIZER, 'tokenizer')
# return _GLOBAL_TOKENIZER
def get_tensorboard_writer():
"""Return tensorboard writer. It can be None so no need
to check if it is initialized."""
return _GLOBAL_TENSORBOARD_WRITER
def get_adlr_autoresume():
"""ADLR autoresume object. It can be None so no need
to check if it is initialized."""
return _GLOBAL_ADLR_AUTORESUME
def get_timers():
"""Return timers."""
_ensure_var_is_initialized(_GLOBAL_TIMERS, 'timers')
return _GLOBAL_TIMERS
def set_global_variables(extra_args_provider=None, args_defaults={},
ignore_unknown_args=False):
"""Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers."""
args = _parse_args(extra_args_provider=extra_args_provider,
defaults=args_defaults,
ignore_unknown_args=ignore_unknown_args)
_build_num_microbatches_calculator(args)
# if args.vocab_file:
# _ = _build_tokenizer(args)
_set_tensorboard_writer(args)
_set_adlr_autoresume(args)
_set_timers()
def _parse_args(extra_args_provider=None, defaults={},
ignore_unknown_args=False):
"""Parse entire arguments."""
global _GLOBAL_ARGS
_ensure_var_is_not_initialized(_GLOBAL_ARGS, 'args')
_GLOBAL_ARGS = parse_args(extra_args_provider=extra_args_provider,
defaults=defaults,
ignore_unknown_args=ignore_unknown_args)
return _GLOBAL_ARGS
def _build_num_microbatches_calculator(args):
global _GLOBAL_NUM_MICROBATCHES_CALCULATOR
_ensure_var_is_not_initialized(_GLOBAL_NUM_MICROBATCHES_CALCULATOR,
'num microbatches calculator')
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = build_num_microbatches_calculator(
args)
# def _build_tokenizer(args):
# """Initialize tokenizer."""
# global _GLOBAL_TOKENIZER
# _ensure_var_is_not_initialized(_GLOBAL_TOKENIZER, 'tokenizer')
# _GLOBAL_TOKENIZER = build_tokenizer(args)
# return _GLOBAL_TOKENIZER
# def rebuild_tokenizer(args):
# global _GLOBAL_TOKENIZER
# _GLOBAL_TOKENIZER = None
# return _build_tokenizer(args)
def _set_tensorboard_writer(args):
"""Set tensorboard writer."""
global _GLOBAL_TENSORBOARD_WRITER
_ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER,
'tensorboard writer')
if hasattr(args, 'tensorboard_dir') and \
args.tensorboard_dir and args.rank == (args.world_size - 1):
try:
from torch.utils.tensorboard import SummaryWriter
print('> setting tensorboard ...')
_GLOBAL_TENSORBOARD_WRITER = SummaryWriter(
log_dir=args.tensorboard_dir,
max_queue=args.tensorboard_queue_size)
except ModuleNotFoundError:
print('WARNING: TensorBoard writing requested but is not '
'available (are you using PyTorch 1.1.0 or later?), '
'no TensorBoard logs will be written.', flush=True)
def _set_adlr_autoresume(args):
"""Initialize ADLR autoresume."""
global _GLOBAL_ADLR_AUTORESUME
_ensure_var_is_not_initialized(_GLOBAL_ADLR_AUTORESUME, 'adlr autoresume')
if args.adlr_autoresume:
if args.rank == 0:
print('enabling autoresume ...', flush=True)
sys.path.append(os.environ.get('SUBMIT_SCRIPTS', '.'))
try:
from userlib.auto_resume import AutoResume
except BaseException:
print('ADLR autoresume is not available, exiting ...')
sys.exit()
_GLOBAL_ADLR_AUTORESUME = AutoResume
def _set_timers():
"""Initialize timers."""
global _GLOBAL_TIMERS
_ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers')
_GLOBAL_TIMERS = Timers()
def _ensure_var_is_initialized(var, name):
"""Make sure the input variable is not None."""
assert var is not None, '{} is not initialized.'.format(name)
def _ensure_var_is_not_initialized(var, name):
"""Make sure the input variable is not None."""
assert var is None, '{} is already initialized.'.format(name)
class _Timer:
"""Timer."""
def __init__(self, name):
self.name_ = name
self.elapsed_ = 0.0
self.started_ = False
self.start_time = time.time()
def start(self):
"""Start the timer."""
assert not self.started_, 'timer has already been started'
torch.cuda.synchronize()
self.start_time = time.time()
self.started_ = True
def stop(self):
"""Stop the timer."""
assert self.started_, 'timer is not started'
torch.cuda.synchronize()
self.elapsed_ += (time.time() - self.start_time)
self.started_ = False
def reset(self):
"""Reset timer."""
self.elapsed_ = 0.0
self.started_ = False
def elapsed(self, reset=True):
"""Calculate the elapsed time."""
started_ = self.started_
# If the timing in progress, end it first.
if self.started_:
self.stop()
# Get the elapsed time.
elapsed_ = self.elapsed_
# Reset the elapsed time
if reset:
self.reset()
# If timing was in progress, set it back.
if started_:
self.start()
return elapsed_
class Timers:
"""Group of timers."""
def __init__(self):
self.timers = {}
def __call__(self, name):
if name not in self.timers:
self.timers[name] = _Timer(name)
return self.timers[name]
def write(self, names, writer, iteration, normalizer=1.0, reset=False):
"""Write timers to a tensorboard writer"""
# currently when using add_scalars,
# torch.utils.add_scalars makes each timer its own run, which
# polutes the runs list, so we just add each as a scalar
assert normalizer > 0.0
for name in names:
value = self.timers[name].elapsed(reset=reset) / normalizer
writer.add_scalar(name + '-time', value, iteration)
def log(self, names, normalizer=1.0, reset=True):
"""Log a group of timers."""
assert normalizer > 0.0
string = 'time (ms)'
for name in names:
elapsed_time = self.timers[name].elapsed(
reset=reset) * 1000.0 / normalizer
string += ' | {}: {:.2f}'.format(name, elapsed_time)
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == (
torch.distributed.get_world_size() - 1):
print(string, flush=True)
else:
print(string, flush=True)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
def divide(numerator, denominator):
"""Ensure that numerator is divisible by the denominator and return
the division value."""
ensure_divisibility(numerator, denominator)
return numerator // denominator
def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
"""Split a tensor along its last dimension.
Arguments:
tensor: input tensor.
num_partitions: number of partitions to split the tensor
contiguous_split_chunks: If True, make each chunk contiguous
in memory.
"""
# Get the size and dimension.
last_dim = tensor.dim() - 1
last_dim_size = divide(tensor.size()[last_dim], num_partitions)
# Split.
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
# Note: torch.split does not create contiguous tensors by default.
if contiguous_split_chunks:
return tuple(chunk.contiguous() for chunk in tensor_list)
return tensor_list
class VocabUtility:
"""Split the vocabulary into `world_size` chunks amd return the
first and last index of the vocabulary belonging to the `rank`
partition: Note that indecies in [fist, last)"""
@staticmethod
def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size):
index_f = rank * per_partition_vocab_size
index_l = index_f + per_partition_vocab_size
return index_f, index_l
@staticmethod
def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
per_partition_vocab_size = divide(global_vocab_size, world_size)
return VocabUtility.vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size)
...@@ -33,6 +33,13 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda( ...@@ -33,6 +33,13 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
std::vector<std::vector<at::Tensor>> tensor_lists, std::vector<std::vector<at::Tensor>> tensor_lists,
at::optional<bool> per_tensor_python); at::optional<bool> per_tensor_python);
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_scale_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
float scale,
at::optional<bool> per_tensor_python);
void multi_tensor_lamb_stage1_cuda( void multi_tensor_lamb_stage1_cuda(
int chunk_size, int chunk_size,
at::Tensor noop_flag, at::Tensor noop_flag,
...@@ -121,6 +128,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -121,6 +128,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"out = a*x + b*y for a list of contiguous tensors"); "out = a*x + b*y for a list of contiguous tensors");
m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda, m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
"Computes L2 norm for a list of contiguous tensors"); "Computes L2 norm for a list of contiguous tensors");
m.def("multi_tensor_l2norm_scale", &multi_tensor_l2norm_scale_cuda,
"Computes L2 norm for a list of contiguous tensors and does scaling");
m.def("multi_tensor_lamb_stage1_cuda", &multi_tensor_lamb_stage1_cuda, m.def("multi_tensor_lamb_stage1_cuda", &multi_tensor_lamb_stage1_cuda,
"Computes update part of LAMB optimizer"); "Computes update part of LAMB optimizer");
m.def("multi_tensor_lamb_stage2_cuda", &multi_tensor_lamb_stage2_cuda, m.def("multi_tensor_lamb_stage2_cuda", &multi_tensor_lamb_stage2_cuda,
......
#include <torch/extension.h>
#include <torch/torch.h>
#include <vector>
#include <stdio.h>
template <typename T>
int linear_bias_forward_cuda(at::Tensor input, T *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace);
template <typename T>
int linear_bias_backward_cuda(T *input, T *weight, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, T *d_input, void *lt_workspace);
template <typename T>
int linear_gelu_linear_forward_cuda(T *input, T *weight1, T *bias1, T *weight2, T *bias2, int in_features, int hidden_features, int batch_size, int out_features, T *output1, T *output2, T *gelu_in, void *lt_workspace) ;
template <typename T>
int linear_gelu_linear_backward_cuda(T *input, T *gelu_in, T *output1, T *weight1, T *weight2, T *d_output1, T *d_output2, int in_features, int batch_size, int hidden_features, int out_features, T *d_weight1, T *d_weight2, T *d_bias1, T *d_bias2, T *d_input, void *lt_workspace);
at::Tensor linear_bias_forward(at::Tensor input, at::Tensor weight, at::Tensor bias) {
auto batch_size = input.size(0);
auto in_features = input.size(1);
int out_features = weight.size(0);
//auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
// create output/workspace tensor
auto out = at::empty({batch_size, out_features}, input.type());
//auto reserved_space = at::empty({reserved_size}, inputs[0].type());
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
auto lt_workspace = at::empty({1 << 22}, input.type());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "linear_bias_forward", [&] {
scalar_t* w_ptr = weight.data_ptr<scalar_t>();
scalar_t* b_ptr = bias.data_ptr<scalar_t>();
auto result = linear_bias_forward_cuda<scalar_t>(
input,
w_ptr,
bias,
in_features,
batch_size,
out_features,
out,
//out.data_ptr<scalar_t>(),
// reserved_space.data_ptr<scalar_t>(),
(void*) (lt_workspace.data_ptr<scalar_t>()));
});
return {out};
}
std::vector<at::Tensor> linear_bias_backward(at::Tensor input, at::Tensor weight, at::Tensor d_output) {
auto batch_size = input.size(0);
auto in_features = input.size(1);
int out_features = weight.size(0);
//auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
// create output/workspace tensor
auto d_weight = at::empty({out_features, in_features}, input.type());
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600
auto d_bias = d_output.view({-1, out_features}).sum(0, false);
#else
auto d_bias = at::empty({out_features}, input.type());
#endif
auto d_input = at::empty({batch_size, in_features}, input.type());
//auto reserved_space = at::empty({reserved_size}, inputs[0].type());
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
auto lt_workspace = at::empty({1 << 22}, input.type());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "linear_bias_backward", [&] {
scalar_t* w_ptr = weight.data_ptr<scalar_t>();
scalar_t* d_b_ptr = d_bias.data_ptr<scalar_t>();
auto result = linear_bias_backward_cuda<scalar_t>(
input.data_ptr<scalar_t>(),
w_ptr,
d_output.data_ptr<scalar_t>(),
in_features,
batch_size,
out_features,
d_weight.data_ptr<scalar_t>(),
d_bias.data_ptr<scalar_t>(),
d_input.data_ptr<scalar_t>(),
// reserved_space.data_ptr<scalar_t>(),
(void*) (lt_workspace.data_ptr<scalar_t>()));
});
return {d_input, d_weight, d_bias};
}
std::vector<at::Tensor> linear_gelu_linear_forward(at::Tensor input, at::Tensor weight1, at::Tensor bias1, at::Tensor weight2, at::Tensor bias2) {
auto batch_size = input.size(0);
auto in_features = input.size(1);
int hidden_features = weight1.size(0);
int out_features = weight2.size(0);
//auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
// create output/workspace tensor
auto output1 = at::empty({batch_size, hidden_features}, input.type());
auto gelu_in = at::empty({batch_size, hidden_features}, input.type());
auto output2 = at::empty({batch_size, out_features}, input.type());
//auto reserved_space = at::empty({reserved_size}, inputs[0].type());
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
auto lt_workspace = at::empty({1 << 22}, input.type());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "linear_gelu_linear_forward", [&] {
scalar_t* w1_ptr = weight1.data_ptr<scalar_t>();
scalar_t* b1_ptr = bias1.data_ptr<scalar_t>();
scalar_t* w2_ptr = weight2.data_ptr<scalar_t>();
scalar_t* b2_ptr = bias2.data_ptr<scalar_t>();
auto result = linear_gelu_linear_forward_cuda<scalar_t>(
input.data_ptr<scalar_t>(),
w1_ptr,
b1_ptr,
w2_ptr,
b2_ptr,
in_features,
hidden_features,
batch_size,
out_features,
output1.data_ptr<scalar_t>(),
output2.data_ptr<scalar_t>(),
gelu_in.data_ptr<scalar_t>(),
// reserved_space.data_ptr<scalar_t>(),
(void*) (lt_workspace.data_ptr<scalar_t>()));
});
return {output1, output2, gelu_in};
}
std::vector<at::Tensor> linear_gelu_linear_backward(at::Tensor input, at::Tensor gelu_in, at::Tensor output1, at::Tensor weight1, at::Tensor weight2, at::Tensor d_output2) {
auto batch_size = input.size(0);
auto in_features = input.size(1);
int hidden_features = weight1.size(0);
int out_features = weight2.size(0);
//auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
// create output/workspace tensor
auto d_weight1 = at::empty({hidden_features, in_features}, input.type());
auto d_weight2 = at::empty({out_features, hidden_features}, input.type());
auto d_bias1 = at::empty({hidden_features}, input.type());
auto d_bias2 = at::empty({out_features}, input.type());
auto d_input = at::empty({batch_size, in_features}, input.type());
auto d_output1 = at::empty({batch_size, hidden_features}, input.type());
//auto reserved_space = at::empty({reserved_size}, inputs[0].type());
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
auto lt_workspace = at::empty({1 << 22}, input.type());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "linear_bias_backward", [&] {
//scalar_t* w_ptr = weight.data_ptr<scalar_t>();
//scalar_t* d_b_ptr = d_bias.data_ptr<scalar_t>();
auto result = linear_gelu_linear_backward_cuda<scalar_t>(
input.data_ptr<scalar_t>(),
gelu_in.data_ptr<scalar_t>(),
output1.data_ptr<scalar_t>(),
weight1.data_ptr<scalar_t>(),
weight2.data_ptr<scalar_t>(),
d_output1.data_ptr<scalar_t>(),
d_output2.data_ptr<scalar_t>(),
in_features,
batch_size,
hidden_features,
out_features,
d_weight1.data_ptr<scalar_t>(),
d_weight2.data_ptr<scalar_t>(),
d_bias1.data_ptr<scalar_t>(),
d_bias2.data_ptr<scalar_t>(),
d_input.data_ptr<scalar_t>(),
// reserved_space.data_ptr<scalar_t>(),
(void*) (lt_workspace.data_ptr<scalar_t>()));
});
return {d_input, d_weight1, d_bias1, d_weight2, d_bias2};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("linear_bias_forward", &linear_bias_forward, "linear bias forward");
m.def("linear_bias_backward", &linear_bias_backward, "linear bias backward");
m.def("linear_gelu_linear_forward", &linear_gelu_linear_forward, "linear gelu linear forward");
m.def("linear_gelu_linear_backward", &linear_gelu_linear_backward, "linear gelu linear backward");
}
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <torch/torch.h>
/* Includes, cuda */
#include <cublas_v2.h>
#include <cuda_runtime.h>
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
// includes cublaslt
#include <cublasLt.h>
#endif
// FP64 Wrapper around cublas GEMMEx
cublasStatus_t gemm_bias(
cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float* alpha,
double* A,
int lda,
double* B,
int ldb,
const float* beta,
double* C,
int ldc) {
#ifdef __HIP_PLATFORM_HCC__
return rocblas_gemm_ex(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
rocblas_datatype_f64_r,
lda,
B,
rocblas_datatype_f64_r,
ldb,
beta,
C,
rocblas_datatype_f64_r,
ldc,
C,
rocblas_datatype_f64_r,
ldc,
rocblas_datatype_f64_r,
rocblas_gemm_algo_standard,
0,
0);
#else
return cublasGemmEx(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
CUDA_R_64F,
lda,
B,
CUDA_R_64F,
ldb,
beta,
C,
CUDA_R_64F,
ldc,
CUDA_R_64F,
CUBLAS_GEMM_DEFAULT);
#endif
}
// FP32 Wrapper around cublas GEMMEx
cublasStatus_t gemm_bias(
cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float* alpha,
float* A,
int lda,
float* B,
int ldb,
const float* beta,
float* C,
int ldc) {
#ifdef __HIP_PLATFORM_HCC__
return rocblas_gemm_ex(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
rocblas_datatype_f32_r,
lda,
B,
rocblas_datatype_f32_r,
ldb,
beta,
C,
rocblas_datatype_f32_r,
ldc,
C,
rocblas_datatype_f32_r,
ldc,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard,
0,
0);
#else
return cublasGemmEx(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
CUDA_R_32F,
lda,
B,
CUDA_R_32F,
ldb,
beta,
C,
CUDA_R_32F,
ldc,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT);
#endif
}
// FP16 Tensor core wrapper around cublas GEMMEx
cublasStatus_t gemm_bias(
cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float* alpha,
at::Half* A,
int lda,
at::Half* B,
int ldb,
const float* beta,
at::Half* C,
int ldc) {
#ifdef __HIP_PLATFORM_HCC__
return rocblas_gemm_ex(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
rocblas_datatype_f16_r,
lda,
B,
rocblas_datatype_f16_r,
ldb,
beta,
C,
rocblas_datatype_f16_r,
ldc,
C,
rocblas_datatype_f16_r,
ldc,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard,
0,
0);
#else
return cublasGemmEx(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
CUDA_R_16F,
lda,
B,
CUDA_R_16F,
ldb,
beta,
C,
CUDA_R_16F,
ldc,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
}
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
int gemm_bias_lt(
cublasLtHandle_t ltHandle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float *alpha, /* host pointer */
at::Half* A,
int lda,
at::Half* B,
int ldb,
const float *beta, /* host pointer */
at::Half* C,
int ldc,
void *workspace,
size_t workspaceSize,
cudaStream_t stream,
bool use_bias,
const void* bias) {
cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
cublasLtMatmulDescOpaque_t operationDesc = {};
cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
cublasLtMatmulPreferenceOpaque_t preference = {};
int returnedResults = 0;
cublasLtMatmulHeuristicResult_t heuristicResult = {};
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
// Create operation descriptor; see cublasLtMatmulDescAttributes_t
// for details about defaults; here we just set the transforms for
// A and B.
status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
if (use_bias) {
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP;
}
epilogue = CUBLASLT_EPILOGUE_BIAS;
}
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP;
}
// Create matrix descriptors. Not setting any extra attributes.
status = cublasLtMatrixLayoutInit(
&Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatrixLayoutInit(
&Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
// Create preference handle; In general, extra attributes can be
// used here to disable tensor ops or to make sure algo selected
// will work with badly aligned A, B, C. However, for simplicity
// here we assume A,B,C are always well aligned (e.g., directly
// come from cudaMalloc)
status = cublasLtMatmulPreferenceInit(&preference);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulPreferenceSetAttribute(
&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
// We just need the best available heuristic to try and run matmul.
// There is no guarantee that this will work. For example, if A is
// badly aligned, you can request more (e.g. 32) algos and try to
// run them one by one until something works.
status = cublasLtMatmulAlgoGetHeuristic(
ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
if (returnedResults == 0) {
status = CUBLAS_STATUS_NOT_SUPPORTED;
goto CLEANUP;
}
status = cublasLtMatmul(ltHandle,
&operationDesc,
alpha,
A,
&Adesc,
B,
&Bdesc,
beta,
C,
&Cdesc,
C,
&Cdesc,
//&heuristicResult.algo,
NULL,
workspace,
workspaceSize,
stream);
CLEANUP:
// Descriptors are no longer needed as all GPU work was already
// enqueued.
return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}
int gemm_bias_lt(
cublasLtHandle_t ltHandle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float *alpha, /* host pointer */
double* A,
int lda,
double* B,
int ldb,
const float *beta, /* host pointer */
double* C,
int ldc,
void *workspace,
size_t workspaceSize,
cudaStream_t stream,
bool use_bias,
const void* bias) {
return 1;
}
int gemm_bias_lt(
cublasLtHandle_t ltHandle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float *alpha, /* host pointer */
float *A,
int lda,
float *B,
int ldb,
const float *beta, /* host pointer */
float *C,
int ldc,
void *workspace,
size_t workspaceSize,
cudaStream_t stream,
bool use_bias,
const void* bias) {
cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
cublasLtMatmulDescOpaque_t operationDesc = {};
cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
cublasLtMatmulPreferenceOpaque_t preference = {};
int returnedResults = 0;
cublasLtMatmulHeuristicResult_t heuristicResult = {};
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
// Create operation descriptor; see cublasLtMatmulDescAttributes_t
// for details about defaults; here we just set the transforms for
// A and B.
status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
if (use_bias) {
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP;
}
epilogue = CUBLASLT_EPILOGUE_BIAS;
}
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP;
}
// Create matrix descriptors. Not setting any extra attributes.
status = cublasLtMatrixLayoutInit(
&Adesc, CUDA_R_32F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatrixLayoutInit(
&Bdesc, CUDA_R_32F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_32F, m, n, ldc);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
// Create preference handle; In general, extra attributes can be
// used here to disable tensor ops or to make sure algo selected
// will work with badly aligned A, B, C. However, for simplicity
// here we assume A,B,C are always well aligned (e.g., directly
// come from cudaMalloc)
status = cublasLtMatmulPreferenceInit(&preference);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulPreferenceSetAttribute(
&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
// We just need the best available heuristic to try and run matmul.
// There is no guarantee that this will work. For example, if A is
// badly aligned, you can request more (e.g. 32) algos and try to
// run them one by one until something works.
status = cublasLtMatmulAlgoGetHeuristic(
ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
if (returnedResults == 0) {
status = CUBLAS_STATUS_NOT_SUPPORTED;
goto CLEANUP;
}
status = cublasLtMatmul(ltHandle,
&operationDesc,
alpha,
A,
&Adesc,
B,
&Bdesc,
beta,
C,
&Cdesc,
C,
&Cdesc,
&heuristicResult.algo,
workspace,
workspaceSize,
stream);
CLEANUP:
// Descriptors are no longer needed as all GPU work was already
// enqueued.
return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}
int gemm_bias_gelu_lt(
cublasLtHandle_t ltHandle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float *alpha, /* host pointer */
at::Half* A,
int lda,
at::Half* B,
int ldb,
const float *beta, /* host pointer */
at::Half* C,
int64_t ldc,
void *workspace,
size_t workspaceSize,
cudaStream_t stream,
bool use_bias,
const void* gelu_in,
const void* bias) {
cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
cublasLtMatmulDescOpaque_t operationDesc = {};
cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
cublasLtMatmulPreferenceOpaque_t preference = {};
int returnedResults = 0;
cublasLtMatmulHeuristicResult_t heuristicResult = {};
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_GELU_AUX;
// Create operation descriptor; see cublasLtMatmulDescAttributes_t
// for details about defaults; here we just set the transforms for
// A and B.
status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &gelu_in, sizeof(gelu_in));
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc));
if (use_bias) {
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP;
}
epilogue = CUBLASLT_EPILOGUE_GELU_AUX_BIAS;
}
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP;
}
// Create matrix descriptors. Not setting any extra attributes.
status = cublasLtMatrixLayoutInit(
&Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatrixLayoutInit(
&Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
// Create preference handle; In general, extra attributes can be
// used here to disable tensor ops or to make sure algo selected
// will work with badly aligned A, B, C. However, for simplicity
// here we assume A,B,C are always well aligned (e.g., directly
// come from cudaMalloc)
status = cublasLtMatmulPreferenceInit(&preference);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulPreferenceSetAttribute(
&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
// We just need the best available heuristic to try and run matmul.
// There is no guarantee that this will work. For example, if A is
// badly aligned, you can request more (e.g. 32) algos and try to
// run them one by one until something works.
status = cublasLtMatmulAlgoGetHeuristic(
ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
if (returnedResults == 0) {
status = CUBLAS_STATUS_NOT_SUPPORTED;
goto CLEANUP;
}
status = cublasLtMatmul(ltHandle,
&operationDesc,
alpha,
A,
&Adesc,
B,
&Bdesc,
beta,
C,
&Cdesc,
C,
&Cdesc,
//&heuristicResult.algo,
NULL,
workspace,
workspaceSize,
stream);
CLEANUP:
// Descriptors are no longer needed as all GPU work was already
// enqueued.
return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}
int gemm_bias_gelu_lt(
cublasLtHandle_t ltHandle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float *alpha, /* host pointer */
double* A,
int lda,
double* B,
int ldb,
const float *beta, /* host pointer */
double* C,
int ldc,
void *workspace,
size_t workspaceSize,
cudaStream_t stream,
bool use_bias,
const void *gelu_in,
const void* bias) {
return 1;
}
int gemm_bias_gelu_lt(
cublasLtHandle_t ltHandle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float *alpha, /* host pointer */
float *A,
int lda,
float *B,
int ldb,
const float *beta, /* host pointer */
float *C,
int64_t ldc,
void *workspace,
size_t workspaceSize,
cudaStream_t stream,
bool use_bias,
const void* gelu_in,
const void* bias) {
cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
cublasLtMatmulDescOpaque_t operationDesc = {};
cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
cublasLtMatmulPreferenceOpaque_t preference = {};
int returnedResults = 0;
cublasLtMatmulHeuristicResult_t heuristicResult = {};
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_GELU_AUX;
// Create operation descriptor; see cublasLtMatmulDescAttributes_t
// for details about defaults; here we just set the transforms for
// A and B.
status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &gelu_in, sizeof(gelu_in));
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc));
if (use_bias) {
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP;
}
epilogue = CUBLASLT_EPILOGUE_GELU_AUX_BIAS;
}
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP;
}
// Create matrix descriptors. Not setting any extra attributes.
status = cublasLtMatrixLayoutInit(
&Adesc, CUDA_R_32F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatrixLayoutInit(
&Bdesc, CUDA_R_32F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_32F, m, n, ldc);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
// Create preference handle; In general, extra attributes can be
// used here to disable tensor ops or to make sure algo selected
// will work with badly aligned A, B, C. However, for simplicity
// here we assume A,B,C are always well aligned (e.g., directly
// come from cudaMalloc)
status = cublasLtMatmulPreferenceInit(&preference);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulPreferenceSetAttribute(
&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
// We just need the best available heuristic to try and run matmul.
// There is no guarantee that this will work. For example, if A is
// badly aligned, you can request more (e.g. 32) algos and try to
// run them one by one until something works.
status = cublasLtMatmulAlgoGetHeuristic(
ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
if (returnedResults == 0) {
status = CUBLAS_STATUS_NOT_SUPPORTED;
goto CLEANUP;
}
status = cublasLtMatmul(ltHandle,
&operationDesc,
alpha,
A,
&Adesc,
B,
&Bdesc,
beta,
C,
&Cdesc,
C,
&Cdesc,
//&heuristicResult.algo,
NULL,
workspace,
workspaceSize,
stream);
CLEANUP:
// Descriptors are no longer needed as all GPU work was already
// enqueued.
return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}
int gemm_bgradb_lt(
cublasLtHandle_t ltHandle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float *alpha, /* host pointer */
at::Half* A,
int lda,
at::Half* B,
int ldb,
const float *beta, /* host pointer */
at::Half* C,
int ldc,
void *workspace,
size_t workspaceSize,
cudaStream_t stream,
bool use_bias,
const void* bgrad) {
cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
cublasLtMatmulDescOpaque_t operationDesc = {};
cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
cublasLtMatmulPreferenceOpaque_t preference = {};
int returnedResults = 0;
cublasLtMatmulHeuristicResult_t heuristicResult = {};
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
// Create operation descriptor; see cublasLtMatmulDescAttributes_t
// for details about defaults; here we just set the transforms for
// A and B.
status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
if (use_bias) {
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));
if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP;
}
epilogue = CUBLASLT_EPILOGUE_BGRADB;
}
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP;
}
// Create matrix descriptors. Not setting any extra attributes.
status = cublasLtMatrixLayoutInit(
&Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatrixLayoutInit(
&Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
// Create preference handle; In general, extra attributes can be
// used here to disable tensor ops or to make sure algo selected
// will work with badly aligned A, B, C. However, for simplicity
// here we assume A,B,C are always well aligned (e.g., directly
// come from cudaMalloc)
status = cublasLtMatmulPreferenceInit(&preference);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulPreferenceSetAttribute(
&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
// We just need the best available heuristic to try and run matmul.
// There is no guarantee that this will work. For example, if A is
// badly aligned, you can request more (e.g. 32) algos and try to
// run them one by one until something works.
status = cublasLtMatmulAlgoGetHeuristic(
ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
if (returnedResults == 0) {
status = CUBLAS_STATUS_NOT_SUPPORTED;
goto CLEANUP;
}
status = cublasLtMatmul(ltHandle,
&operationDesc,
alpha,
A,
&Adesc,
B,
&Bdesc,
beta,
C,
&Cdesc,
C,
&Cdesc,
//&heuristicResult.algo,
NULL,
workspace,
workspaceSize,
stream);
CLEANUP:
// Descriptors are no longer needed as all GPU work was already
// enqueued.
return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}
int gemm_bgradb_lt(
cublasLtHandle_t ltHandle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float *alpha, /* host pointer */
double* A,
int lda,
double* B,
int ldb,
const float *beta, /* host pointer */
double* C,
int ldc,
void *workspace,
size_t workspaceSize,
cudaStream_t stream,
bool use_bias,
const void* bgrad) {
return 1;
}
int gemm_bgradb_lt(
cublasLtHandle_t ltHandle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float *alpha, /* host pointer */
float *A,
int lda,
float *B,
int ldb,
const float *beta, /* host pointer */
float *C,
int ldc,
void *workspace,
size_t workspaceSize,
cudaStream_t stream,
bool use_bias,
const void* bgrad) {
cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
cublasLtMatmulDescOpaque_t operationDesc = {};
cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
cublasLtMatmulPreferenceOpaque_t preference = {};
int returnedResults = 0;
cublasLtMatmulHeuristicResult_t heuristicResult = {};
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
// Create operation descriptor; see cublasLtMatmulDescAttributes_t
// for details about defaults; here we just set the transforms for
// A and B.
status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
if (use_bias) {
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));
if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP;
}
epilogue = CUBLASLT_EPILOGUE_BGRADB;
}
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP;
}
// Create matrix descriptors. Not setting any extra attributes.
status = cublasLtMatrixLayoutInit(
&Adesc, CUDA_R_32F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatrixLayoutInit(
&Bdesc, CUDA_R_32F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_32F, m, n, ldc);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
// Create preference handle; In general, extra attributes can be
// used here to disable tensor ops or to make sure algo selected
// will work with badly aligned A, B, C. However, for simplicity
// here we assume A,B,C are always well aligned (e.g., directly
// come from cudaMalloc)
status = cublasLtMatmulPreferenceInit(&preference);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulPreferenceSetAttribute(
&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
// We just need the best available heuristic to try and run matmul.
// There is no guarantee that this will work. For example, if A is
// badly aligned, you can request more (e.g. 32) algos and try to
// run them one by one until something works.
status = cublasLtMatmulAlgoGetHeuristic(
ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
if (returnedResults == 0) {
status = CUBLAS_STATUS_NOT_SUPPORTED;
goto CLEANUP;
}
status = cublasLtMatmul(ltHandle,
&operationDesc,
alpha,
A,
&Adesc,
B,
&Bdesc,
beta,
C,
&Cdesc,
C,
&Cdesc,
&heuristicResult.algo,
workspace,
workspaceSize,
stream);
CLEANUP:
// Descriptors are no longer needed as all GPU work was already
// enqueued.
return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}
int gemm_dgelu_bgradb_lt(
cublasLtHandle_t ltHandle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float *alpha, /* host pointer */
at::Half* A,
int lda,
at::Half* B,
int ldb,
const float *beta, /* host pointer */
at::Half* C,
int64_t ldc,
void *workspace,
size_t workspaceSize,
cudaStream_t stream,
const void *gelu_in,
const void *bgrad) {
cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
cublasLtMatmulDescOpaque_t operationDesc = {};
cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
cublasLtMatmulPreferenceOpaque_t preference = {};
int returnedResults = 0;
cublasLtMatmulHeuristicResult_t heuristicResult = {};
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DGELU_BGRAD;
// Create operation descriptor; see cublasLtMatmulDescAttributes_t
// for details about defaults; here we just set the transforms for
// A and B.
status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));
if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP;
}
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &gelu_in, sizeof(gelu_in));
if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP;
}
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc));
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP;
}
// Create matrix descriptors. Not setting any extra attributes.
status = cublasLtMatrixLayoutInit(
&Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatrixLayoutInit(
&Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
// Create preference handle; In general, extra attributes can be
// used here to disable tensor ops or to make sure algo selected
// will work with badly aligned A, B, C. However, for simplicity
// here we assume A,B,C are always well aligned (e.g., directly
// come from cudaMalloc)
status = cublasLtMatmulPreferenceInit(&preference);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulPreferenceSetAttribute(
&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
// We just need the best available heuristic to try and run matmul.
// There is no guarantee that this will work. For example, if A is
// badly aligned, you can request more (e.g. 32) algos and try to
// run them one by one until something works.
status = cublasLtMatmulAlgoGetHeuristic(
ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
if (returnedResults == 0) {
status = CUBLAS_STATUS_NOT_SUPPORTED;
goto CLEANUP;
}
status = cublasLtMatmul(ltHandle,
&operationDesc,
alpha,
A,
&Adesc,
B,
&Bdesc,
beta,
C,
&Cdesc,
C,
&Cdesc,
//&heuristicResult.algo,
NULL,
workspace,
workspaceSize,
stream);
CLEANUP:
// Descriptors are no longer needed as all GPU work was already
// enqueued.
return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}
int gemm_dgelu_bgradb_lt(
cublasLtHandle_t ltHandle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float *alpha, /* host pointer */
double *A,
int lda,
double *B,
int ldb,
const float *beta, /* host pointer */
double *C,
int ldc,
void *workspace,
size_t workspaceSize,
cudaStream_t stream,
const void *gelu_in,
const void *bgrad) {
return 1;
}
int gemm_dgelu_bgradb_lt(
cublasLtHandle_t ltHandle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float *alpha, /* host pointer */
float *A,
int lda,
float *B,
int ldb,
const float *beta, /* host pointer */
float *C,
int64_t ldc,
void *workspace,
size_t workspaceSize,
cudaStream_t stream,
const void *gelu_in,
const void *bgrad) {
cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
cublasLtMatmulDescOpaque_t operationDesc = {};
cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
cublasLtMatmulPreferenceOpaque_t preference = {};
int returnedResults = 0;
cublasLtMatmulHeuristicResult_t heuristicResult = {};
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DGELU_BGRAD;
// Create operation descriptor; see cublasLtMatmulDescAttributes_t
// for details about defaults; here we just set the transforms for
// A and B.
status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));
if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP;
}
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &gelu_in, sizeof(gelu_in));
if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP;
}
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc));
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP;
}
// Create matrix descriptors. Not setting any extra attributes.
status = cublasLtMatrixLayoutInit(
&Adesc, CUDA_R_32F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatrixLayoutInit(
&Bdesc, CUDA_R_32F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_32F, m, n, ldc);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
// Create preference handle; In general, extra attributes can be
// used here to disable tensor ops or to make sure algo selected
// will work with badly aligned A, B, C. However, for simplicity
// here we assume A,B,C are always well aligned (e.g., directly
// come from cudaMalloc)
status = cublasLtMatmulPreferenceInit(&preference);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulPreferenceSetAttribute(
&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
// We just need the best available heuristic to try and run matmul.
// There is no guarantee that this will work. For example, if A is
// badly aligned, you can request more (e.g. 32) algos and try to
// run them one by one until something works.
status = cublasLtMatmulAlgoGetHeuristic(
ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
if (returnedResults == 0) {
status = CUBLAS_STATUS_NOT_SUPPORTED;
goto CLEANUP;
}
status = cublasLtMatmul(ltHandle,
&operationDesc,
alpha,
A,
&Adesc,
B,
&Bdesc,
beta,
C,
&Cdesc,
C,
&Cdesc,
//&heuristicResult.algo,
NULL,
workspace,
workspaceSize,
stream);
CLEANUP:
// Descriptors are no longer needed as all GPU work was already
// enqueued.
return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}
#endif
template <typename T>
int linear_bias_forward_cuda(at::Tensor input, T *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace) {
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
// Get the stream from cublas handle to reuse for biasReLU kernel.
cudaStream_t stream;
cublasGetStream(handle, &stream);
const float alpha = 1.0;
const float beta_zero = 0.0;
const float beta_one = 1.0;
int status = 1;
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
status = gemm_bias_lt(
(cublasLtHandle_t)handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
out_features,
batch_size,
in_features,
&alpha, /* host pointer */
weight,
in_features,
input.data_ptr<T>(),
in_features,
&beta_zero, /* host pointer */
output.data_ptr<T>(),
out_features,
lt_workspace,
1 << 22,
stream,
true,
static_cast<const void*>(bias.data_ptr<T>()));
#endif
if (status != 0){
output.copy_(bias);
status = gemm_bias(
handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
out_features,
batch_size,
in_features,
&alpha,
weight,
in_features,
input.data_ptr<T>(),
in_features,
&beta_one,
output.data_ptr<T>(),
out_features);
}
return status;
}
template <typename T>
int linear_bias_backward_cuda(T *input, T *weight, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, T *d_input, void *lt_workspace) {
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
// Get the stream from cublas handle to reuse for biasReLU kernel.
cudaStream_t stream;
cublasGetStream(handle, &stream);
const float alpha = 1.0;
const float beta_zero = 0.0;
const float beta_one = 1.0;
int status = 1;
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
status = gemm_bgradb_lt(
(cublasLtHandle_t)handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
in_features,
out_features,
batch_size,
&alpha, /* host pointer */
input,
in_features,
d_output,
out_features,
&beta_zero, /* host pointer */
d_weight,
in_features,
lt_workspace,
1 << 22,
stream,
true,
static_cast<const void*>(d_bias));
#endif
if (status != 0){
status = gemm_bias(
handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
in_features,
out_features,
batch_size,
&alpha,
input,
in_features,
d_output,
out_features,
&beta_zero,
d_weight,
in_features);
}
status = gemm_bias(
handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
in_features,
batch_size,
out_features,
&alpha,
weight,
in_features,
d_output,
out_features,
&beta_zero,
d_input,
in_features);
return status;
}
template <typename T>
int linear_gelu_linear_forward_cuda(T *input, T *weight1, T *bias1, T *weight2, T *bias2, int in_features, int hidden_features, int batch_size, int out_features, T *output1, T *output2, T *gelu_in, void *lt_workspace) {
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
// Get the stream from cublas handle to reuse for biasReLU kernel.
cudaStream_t stream;
cublasGetStream(handle, &stream);
const float alpha = 1.0;
const float beta_zero = 0.0;
int status = 1;
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
status = gemm_bias_gelu_lt(
(cublasLtHandle_t)handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
hidden_features,
batch_size,
in_features,
&alpha, /* host pointer */
weight1,
in_features,
input,
in_features,
&beta_zero, /* host pointer */
output1,
hidden_features,
lt_workspace,
1 << 22,
stream,
true,
static_cast<const void*>(gelu_in),
static_cast<const void*>(bias1));
status = gemm_bias_lt(
(cublasLtHandle_t)handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
out_features,
batch_size,
hidden_features,
&alpha, /* host pointer */
weight2,
hidden_features,
output1,
hidden_features,
&beta_zero, /* host pointer */
output2,
out_features,
lt_workspace,
1 << 22,
stream,
true,
static_cast<const void*>(bias2));
return status;
#else
return 1;
#endif
}
template <typename T>
int linear_gelu_linear_backward_cuda(T *input, T *gelu_in, T *output1, T *weight1, T *weight2, T *d_output1, T *d_output2, int in_features, int batch_size, int hidden_features, int out_features, T *d_weight1, T *d_weight2, T *d_bias1, T *d_bias2, T *d_input, void *lt_workspace) {
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
// Get the stream from cublas handle to reuse for biasReLU kernel.
cudaStream_t stream;
cublasGetStream(handle, &stream);
const float alpha = 1.0;
const float beta_zero = 0.0;
const float beta_one = 1.0;
int status = 1;
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
//wgrad for first gemm
status = gemm_bgradb_lt(
(cublasLtHandle_t)handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
hidden_features,
out_features,
batch_size,
&alpha, /* host pointer */
output1,
hidden_features,
d_output2,
out_features,
&beta_zero, /* host pointer */
d_weight2,
hidden_features,
lt_workspace,
1 << 22,
stream,
true,
static_cast<const void*>(d_bias2));
//dgrad for second GEMM
status = gemm_dgelu_bgradb_lt(
(cublasLtHandle_t)handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
hidden_features,
batch_size,
out_features,
&alpha, /* host pointer */
weight2,
hidden_features,
d_output2,
out_features,
&beta_zero, /* host pointer */
d_output1,
hidden_features,
lt_workspace,
1 << 22,
stream,
static_cast<const void*>(gelu_in),
static_cast<const void*>(d_bias1));
//wgrad for the first GEMM
status = gemm_bias(
handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
in_features,
hidden_features,
batch_size,
&alpha,
input,
in_features,
d_output1,
hidden_features,
&beta_zero,
d_weight1,
in_features);
//dgrad for the first GEMM
status = gemm_bias(
handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
in_features,
batch_size,
hidden_features,
&alpha,
weight1,
in_features,
d_output1,
hidden_features,
&beta_zero,
d_input,
in_features);
#endif
return status;
}
template int linear_bias_forward_cuda<at::Half>(at::Tensor input, at::Half *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace);
template int linear_bias_forward_cuda<float>(at::Tensor input, float *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace);
template int linear_bias_forward_cuda<double>(at::Tensor input, double *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace);
template int linear_bias_backward_cuda<at::Half>(at::Half *input, at::Half *weight, at::Half *d_output, int in_features, int batch_size, int out_features, at::Half *d_weight, at::Half *d_bias, at::Half *d_input, void *lt_workspace) ;
template int linear_bias_backward_cuda<float>(float *input, float *weight, float *d_output, int in_features, int batch_size, int out_features, float *d_weight, float *d_bias, float *d_input, void *lt_workspace) ;
template int linear_bias_backward_cuda<double>(double *input, double *weight, double *d_output, int in_features, int batch_size, int out_features, double *d_weight, double *d_bias, double *d_input, void *lt_workspace) ;
template int linear_gelu_linear_forward_cuda<at::Half>(at::Half *input, at::Half *weight1, at::Half *bias1, at::Half *weight2, at::Half *bias2, int in_features, int hidden_features, int batch_size, int out_features, at::Half *output1, at::Half *output2, at::Half *gelu_in, void *lt_workspace) ;
template int linear_gelu_linear_forward_cuda<float>(float *input, float *weight1, float *bias1, float *weight2, float *bias2, int in_features, int hidden_features, int batch_size, int out_features, float *output1, float *output2, float *gelu_in, void *lt_workspace);
template int linear_gelu_linear_forward_cuda<double>(double *input, double *weight1, double *bias1, double *weight2, double *bias2, int in_features, int hidden_features, int batch_size, int out_features, double *output1, double *output2, double *gelu_in, void *lt_workspace) ;
template int linear_gelu_linear_backward_cuda<at::Half>(at::Half *input, at::Half *gelu_in, at::Half *output1, at::Half *weight1, at::Half *weight2, at::Half *d_output1, at::Half *d_output2, int in_features, int batch_size, int hidden_features, int out_features, at::Half *d_weight1, at::Half *d_weight2, at::Half *d_bias1, at::Half *d_bias2, at::Half *d_input, void *lt_workspace);
template int linear_gelu_linear_backward_cuda<float>(float *input, float *gelu_in, float *output1, float *weight1, float *weight2, float *d_output1, float *d_output2, int in_features, int batch_size, int hidden_features, int out_features, float *d_weight1, float *d_weight2, float *d_bias1, float *d_bias2, float *d_input, void *lt_workspace);
template int linear_gelu_linear_backward_cuda<double>(double *input, double *gelu_in, double *output1, double *weight1, double *weight2, double *d_output1, double *d_output2, int in_features, int batch_size, int hidden_features, int out_features, double *d_weight1, double *d_weight2, double *d_bias1, double *d_bias2, double *d_input, void *lt_workspace);
...@@ -130,13 +130,13 @@ std::vector<at::Tensor> layer_norm( ...@@ -130,13 +130,13 @@ std::vector<at::Tensor> layer_norm(
int n1,n2; int n1,n2;
check_args(input,normalized_shape,n1,n2); check_args(input,normalized_shape,n1,n2);
at::Tensor output = at::empty_like(input); at::Tensor output = at::empty_like(input);
at::Tensor mean = at::empty({n1}, input.options().dtype((input.scalar_type()==at::ScalarType::Half || at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half || input.scalar_type()==at::ScalarType::BFloat16 ? at::ScalarType::Float : input.scalar_type()));
input.scalar_type()==at::ScalarType::BFloat16) ? at::ScalarType::Float : input.scalar_type()));
at::Tensor invvar = at::empty_like(mean); at::Tensor invvar = at::empty_like(mean);
cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2, cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2,
normalized_shape,NULL,NULL,epsilon); normalized_shape,NULL,NULL,epsilon);
return {output, mean, invvar}; return {output, mean, invvar};
} }
std::vector<at::Tensor> layer_norm_affine( std::vector<at::Tensor> layer_norm_affine(
at::Tensor input, at::Tensor input,
#ifdef VERSION_GE_1_1 #ifdef VERSION_GE_1_1
...@@ -153,14 +153,35 @@ std::vector<at::Tensor> layer_norm_affine( ...@@ -153,14 +153,35 @@ std::vector<at::Tensor> layer_norm_affine(
int n1,n2; int n1,n2;
check_args(input,normalized_shape,gamma,beta,n1,n2); check_args(input,normalized_shape,gamma,beta,n1,n2);
at::Tensor output = at::empty_like(input); at::Tensor output = at::empty_like(input);
at::Tensor mean = at::empty({n1}, input.options().dtype((input.scalar_type()==at::ScalarType::Half || const auto stats_dtype = (input.scalar_type() == at::ScalarType::Half || input.scalar_type() == at::ScalarType::BFloat16) ? at::ScalarType::Float : input.scalar_type();
input.scalar_type()==at::ScalarType::BFloat16) ? at::ScalarType::Float : input.scalar_type())); at::Tensor mean = at::empty({n1}, input.options().dtype(stats_dtype));
at::Tensor invvar = at::empty_like(mean); at::Tensor invvar = at::empty_like(mean);
cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2, cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2,
normalized_shape,&gamma,&beta,epsilon); normalized_shape,&gamma,&beta,epsilon);
return {output, mean, invvar}; return {output, mean, invvar};
} }
std::vector<at::Tensor> layer_norm_affine_mixed_dtypes(
at::Tensor input,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
at::Tensor gamma,
at::Tensor beta,
double epsilon) {
CHECK_INPUT(input);
int n1, n2;
check_args(input, normalized_shape, n1, n2);
at::Tensor output = at::empty_like(input, gamma.options().dtype(gamma.scalar_type()));
at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type() == at::ScalarType::Half || input.scalar_type() == at::ScalarType::BFloat16 ? at::ScalarType::Float : input.scalar_type()));
at::Tensor invvar = at::empty_like(mean);
cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2,
normalized_shape, &gamma, &beta, epsilon);
return {output, mean, invvar};
}
void cuda_layer_norm_gradient( void cuda_layer_norm_gradient(
at::Tensor* dout, at::Tensor* dout,
at::Tensor* mean, at::Tensor* mean,
...@@ -204,6 +225,7 @@ at::Tensor layer_norm_gradient( ...@@ -204,6 +225,7 @@ at::Tensor layer_norm_gradient(
&grad_input,NULL,NULL); &grad_input,NULL,NULL);
return grad_input; return grad_input;
} }
std::vector<at::Tensor> layer_norm_gradient_affine( std::vector<at::Tensor> layer_norm_gradient_affine(
at::Tensor dout, at::Tensor dout,
at::Tensor mean, at::Tensor mean,
...@@ -239,5 +261,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -239,5 +261,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &layer_norm, "LayerNorm forward (CUDA)"); m.def("forward", &layer_norm, "LayerNorm forward (CUDA)");
m.def("backward_affine", &layer_norm_gradient_affine, "LayerNorm backward (CUDA)"); m.def("backward_affine", &layer_norm_gradient_affine, "LayerNorm backward (CUDA)");
m.def("backward", &layer_norm_gradient, "LayerNorm backward (CUDA)"); m.def("backward", &layer_norm_gradient, "LayerNorm backward (CUDA)");
m.def("forward_affine_mixed_dtypes", &layer_norm_affine_mixed_dtypes, "LayerNorm forward with mixed dtypes (CUDA) compatible with Megatron's implementation");
} }
...@@ -56,7 +56,7 @@ void cuWelfordMuSigma2( ...@@ -56,7 +56,7 @@ void cuWelfordMuSigma2(
const int i1, const int i1,
U& mu, U& mu,
U& sigma2, U& sigma2,
U* buf) U* buf)
{ {
// Assumptions: // Assumptions:
// 1) blockDim.x == warpSize // 1) blockDim.x == warpSize
...@@ -140,7 +140,7 @@ void cuWelfordMuSigma2( ...@@ -140,7 +140,7 @@ void cuWelfordMuSigma2(
const int i1, const int i1,
float& mu, float& mu,
float& sigma2, float& sigma2,
float* buf) float* buf)
{ {
// Assumptions: // Assumptions:
// 1) blockDim.x == warpSize // 1) blockDim.x == warpSize
...@@ -172,8 +172,8 @@ void cuWelfordMuSigma2( ...@@ -172,8 +172,8 @@ void cuWelfordMuSigma2(
for (; l+7 < n2; l+=8*numx) { for (; l+7 < n2; l+=8*numx) {
for (int k = 0; k < 8; k+=2) { for (int k = 0; k < 8; k+=2) {
float2 curr = __half22float2(*((__half2*)(lvals+l+k))); float2 curr = __half22float2(*((__half2*)(lvals+l+k)));
cuWelfordOnlineSum<float>(curr.x,mu,sigma2,count); cuWelfordOnlineSum(curr.x,mu,sigma2,count);
cuWelfordOnlineSum<float>(curr.y,mu,sigma2,count); cuWelfordOnlineSum(curr.y,mu,sigma2,count);
} }
} }
for (; l < n2; ++l) { for (; l < n2; ++l) {
...@@ -282,18 +282,18 @@ struct SharedMemory <double> ...@@ -282,18 +282,18 @@ struct SharedMemory <double>
}; };
} }
template<typename T, typename U> __global__ template<typename T, typename U, typename V> __device__
void cuApplyLayerNorm( void cuApplyLayerNorm_(
T* __restrict__ output_vals, V* __restrict__ output_vals,
U* __restrict__ mean, U* __restrict__ mean,
U* __restrict__ invvar, U* __restrict__ invvar,
const T* __restrict__ vals, const T* __restrict__ vals,
const int n1, const int n1,
const int n2, const int n2,
const U epsilon, const U epsilon,
const T* __restrict__ gamma, const V* __restrict__ gamma,
const T* __restrict__ beta const V* __restrict__ beta
) )
{ {
// Assumptions: // Assumptions:
// 1) blockDim.x == warpSize // 1) blockDim.x == warpSize
...@@ -305,29 +305,47 @@ void cuApplyLayerNorm( ...@@ -305,29 +305,47 @@ void cuApplyLayerNorm(
U mu,sigma2; U mu,sigma2;
cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf); cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf);
const T* lvals = vals + i1*n2; const T* lvals = vals + i1*n2;
T* ovals = output_vals + i1*n2; V* ovals = output_vals + i1*n2;
U c_invvar = rsqrt(sigma2 + epsilon); U c_invvar = rsqrt(sigma2 + epsilon);
const int numx = blockDim.x * blockDim.y; const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x; const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
if (gamma != NULL && beta != NULL) { if (gamma != NULL && beta != NULL) {
for (int i = thrx; i < n2; i+=numx) { for (int i = thrx; i < n2; i+=numx) {
U curr = static_cast<U>(lvals[i]); U curr = static_cast<U>(lvals[i]);
ovals[i] = gamma[i] * static_cast<T>(c_invvar * (curr - mu)) + beta[i]; ovals[i] = gamma[i] * static_cast<V>(c_invvar * (curr - mu)) + beta[i];
} }
} else { } else {
for (int i = thrx; i < n2; i+=numx) { for (int i = thrx; i < n2; i+=numx) {
U curr = static_cast<U>(lvals[i]); U curr = static_cast<U>(lvals[i]);
ovals[i] = static_cast<T>(c_invvar * (curr - mu)); ovals[i] = static_cast<V>(c_invvar * (curr - mu));
} }
} }
if (threadIdx.x == 0 && threadIdx.y == 0) { if (threadIdx.x == 0 && threadIdx.y == 0) {
mean[i1] = mu; mean[i1] = mu;
invvar[i1] = c_invvar; invvar[i1] = c_invvar;
} }
__syncthreads();
} }
} }
template<typename T, typename U> __device__ template<typename T, typename U, typename V=T> __global__
void cuApplyLayerNorm(
V* __restrict__ output_vals,
U* __restrict__ mean,
U* __restrict__ invvar,
const T* __restrict__ vals,
const int n1,
const int n2,
const U epsilon,
const V* __restrict__ gamma,
const V* __restrict__ beta
)
{
cuApplyLayerNorm_<T, U, V>(output_vals, mean, invvar, vals, n1, n2, epsilon, gamma, beta);
}
template<typename T, typename U, typename V> __device__
void cuLoadWriteStridedInputs( void cuLoadWriteStridedInputs(
const int i1_block, const int i1_block,
const int thr_load_row_off, const int thr_load_row_off,
...@@ -337,7 +355,7 @@ void cuLoadWriteStridedInputs( ...@@ -337,7 +355,7 @@ void cuLoadWriteStridedInputs(
U* warp_buf1, U* warp_buf1,
U* warp_buf2, U* warp_buf2,
const T* input, const T* input,
const T* dout, const V* dout,
const int i1_end, const int i1_end,
const int n2, const int n2,
const U* __restrict__ mean, const U* __restrict__ mean,
...@@ -354,9 +372,9 @@ void cuLoadWriteStridedInputs( ...@@ -354,9 +372,9 @@ void cuLoadWriteStridedInputs(
int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
if (i2<n2) { if (i2<n2) {
U curr_input = static_cast<U>(input[load_idx]); U curr_input = static_cast<U>(input[load_idx]);
U curr_dout = static_cast<U>(dout[load_idx]); U curr_dout = static_cast<U>(dout[load_idx]);
warp_buf1[write_idx] = curr_dout; warp_buf1[write_idx] = curr_dout;
warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar; warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar;
} else { } else {
warp_buf1[write_idx] = U(0); warp_buf1[write_idx] = U(0);
warp_buf2[write_idx] = U(0); warp_buf2[write_idx] = U(0);
...@@ -371,7 +389,7 @@ void cuLoadWriteStridedInputs( ...@@ -371,7 +389,7 @@ void cuLoadWriteStridedInputs(
} }
} }
template<typename T, typename U> __device__ template<typename T, typename U, typename V> __device__
void cuLoadAddStridedInputs( void cuLoadAddStridedInputs(
const int i1_block, const int i1_block,
const int thr_load_row_off, const int thr_load_row_off,
...@@ -381,7 +399,7 @@ void cuLoadAddStridedInputs( ...@@ -381,7 +399,7 @@ void cuLoadAddStridedInputs(
U* warp_buf1, U* warp_buf1,
U* warp_buf2, U* warp_buf2,
const T* input, const T* input,
const T* dout, const V* dout,
const int i1_end, const int i1_end,
const int n2, const int n2,
const U* __restrict__ mean, const U* __restrict__ mean,
...@@ -398,17 +416,17 @@ void cuLoadAddStridedInputs( ...@@ -398,17 +416,17 @@ void cuLoadAddStridedInputs(
int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
if (i2<n2) { if (i2<n2) {
U curr_input = static_cast<U>(input[load_idx]); U curr_input = static_cast<U>(input[load_idx]);
U curr_dout = static_cast<U>(dout[load_idx]); U curr_dout = static_cast<U>(dout[load_idx]);
warp_buf1[write_idx] += curr_dout; warp_buf1[write_idx] += curr_dout;
warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar; warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar;
} }
} }
} }
} }
template<typename T, typename U> __global__ template<typename T, typename U, typename V> __global__
void cuComputePartGradGammaBeta( void cuComputePartGradGammaBeta(
const T* __restrict__ dout, const V* __restrict__ dout,
const T* __restrict__ input, const T* __restrict__ input,
const int n1, const int n1,
const int n2, const int n2,
...@@ -455,11 +473,11 @@ void cuComputePartGradGammaBeta( ...@@ -455,11 +473,11 @@ void cuComputePartGradGammaBeta(
for (int offset = blockDim.y/2; offset > 1; offset /= 2) { for (int offset = blockDim.y/2; offset > 1; offset /= 2) {
if (threadIdx.y < offset) { if (threadIdx.y < offset) {
int row1 = threadIdx.y; int row1 = threadIdx.y;
int row2 = threadIdx.y + offset; int row2 = threadIdx.y + offset;
int idx1 = row1*row_stride + threadIdx.x; int idx1 = row1*row_stride + threadIdx.x;
int idx2 = row2*row_stride + threadIdx.x; int idx2 = row2*row_stride + threadIdx.x;
warp_buf1[idx1] += warp_buf1[idx2]; warp_buf1[idx1] += warp_buf1[idx2];
warp_buf2[idx1] += warp_buf2[idx2]; warp_buf2[idx1] += warp_buf2[idx2];
} }
__syncthreads(); __syncthreads();
} }
...@@ -474,19 +492,19 @@ void cuComputePartGradGammaBeta( ...@@ -474,19 +492,19 @@ void cuComputePartGradGammaBeta(
} }
} }
template<typename T, typename U> __global__ template<typename U, typename V> __global__
void cuComputeGradGammaBeta( void cuComputeGradGammaBeta(
const U* part_grad_gamma, const U* part_grad_gamma,
const U* part_grad_beta, const U* part_grad_beta,
const int part_size, const int part_size,
const int n1, const int n1,
const int n2, const int n2,
T* grad_gamma, V* grad_gamma,
T* grad_beta) V* grad_beta)
{ {
// sum partial gradients for gamma and beta // sum partial gradients for gamma and beta
SharedMemory<U> shared; SharedMemory<U> shared;
U* buf = shared.getPointer(); U* buf = shared.getPointer();
int i2 = blockIdx.x * blockDim.x + threadIdx.x; int i2 = blockIdx.x * blockDim.x + threadIdx.x;
if (i2 < n2) { if (i2 < n2) {
// each warp does sequential reductions until reduced part_size is num_warps // each warp does sequential reductions until reduced part_size is num_warps
...@@ -525,16 +543,16 @@ void cuComputeGradGammaBeta( ...@@ -525,16 +543,16 @@ void cuComputeGradGammaBeta(
} }
} }
template<typename T, typename U> __global__ template<typename T, typename U, typename V> __global__
void cuComputeGradInput( void cuComputeGradInput(
const T* __restrict__ dout, const V* __restrict__ dout,
const T* __restrict__ input, const T* __restrict__ input,
const int n1, const int n1,
const int n2, const int n2,
const U* __restrict__ mean, const U* __restrict__ mean,
const U* __restrict__ invvar, const U* __restrict__ invvar,
U epsilon, U epsilon,
const T* gamma, const V* gamma,
T* grad_input) T* grad_input)
{ {
for (int i1=blockIdx.y; i1 < n1; i1 += gridDim.y) { for (int i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
...@@ -543,7 +561,7 @@ void cuComputeGradInput( ...@@ -543,7 +561,7 @@ void cuComputeGradInput(
const U c_mean = mean[i1]; const U c_mean = mean[i1];
const U c_invvar = invvar[i1]; const U c_invvar = invvar[i1];
const T* k_input = input + i1*n2; const T* k_input = input + i1*n2;
const T* k_dout = dout + i1*n2; const V* k_dout = dout + i1*n2;
const int numx = blockDim.x * blockDim.y; const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x; const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
if (gamma != NULL) { if (gamma != NULL) {
...@@ -587,7 +605,7 @@ void cuComputeGradInput( ...@@ -587,7 +605,7 @@ void cuComputeGradInput(
// inter-warp reductions // inter-warp reductions
if (blockDim.y > 1) { if (blockDim.y > 1) {
SharedMemory<U> shared; SharedMemory<U> shared;
U* buf = shared.getPointer(); U* buf = shared.getPointer();
for (int offset = blockDim.y/2; offset > 0; offset /= 2) { for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
// upper half of warps write to shared // upper half of warps write to shared
if (threadIdx.y >= offset && threadIdx.y < 2*offset) { if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
...@@ -612,7 +630,7 @@ void cuComputeGradInput( ...@@ -612,7 +630,7 @@ void cuComputeGradInput(
if (threadIdx.y !=0) { if (threadIdx.y !=0) {
sum_loss1 = buf[2*threadIdx.x]; sum_loss1 = buf[2*threadIdx.x];
sum_loss2 = buf[2*threadIdx.x+1]; sum_loss2 = buf[2*threadIdx.x+1];
} }
} }
// all threads now have the two sums over l // all threads now have the two sums over l
U fH = (U)n2; U fH = (U)n2;
...@@ -639,38 +657,34 @@ void cuComputeGradInput( ...@@ -639,38 +657,34 @@ void cuComputeGradInput(
k_grad_input[l] = static_cast<T>(f_grad_input); k_grad_input[l] = static_cast<T>(f_grad_input);
} }
} }
// prevent race where buf is written again before reads are done
__syncthreads();
} }
} }
template<typename T, typename U> template<typename T, typename U, typename V=T>
void HostApplyLayerNorm( void HostApplyLayerNorm(
T* output, V* output,
U* mean, U* mean,
U* invvar, U* invvar,
const T* input, const T* input,
int n1, int n1,
int n2, int n2,
double epsilon, double epsilon,
const T* gamma, const V* gamma,
const T* beta const V* beta
) )
{ {
auto stream = at::cuda::getCurrentCUDAStream().stream(); auto stream = at::cuda::getCurrentCUDAStream().stream();
const dim3 threads(32,4,1); const dim3 threads(32,4,1);
const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1); const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
int nshared = int nshared =
threads.y > 1 ? threads.y > 1 ?
threads.y*sizeof(U)+(threads.y/2)*sizeof(U) : threads.y*sizeof(U)+(threads.y/2)*sizeof(U) :
0; 0;
cuApplyLayerNorm<<<blocks, threads, nshared, stream>>>( cuApplyLayerNorm<<<blocks, threads, nshared, stream>>>(
output, output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta);
mean,
invvar,
input,
n1,n2,
U(epsilon),
gamma,beta);
} }
void cuda_layer_norm( void cuda_layer_norm(
...@@ -690,34 +704,35 @@ void cuda_layer_norm( ...@@ -690,34 +704,35 @@ void cuda_layer_norm(
double epsilon) double epsilon)
{ {
using namespace at; using namespace at;
DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BFLOAT16(input->scalar_type(), 0, "layer_norm_cuda_kernel", DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
using accscalar_t = at::acc_type<scalar_t_0, true>; input->scalar_type(), output->scalar_type(), "layer_norm_cuda_kernel",
HostApplyLayerNorm( using accscalar_t = at::acc_type<scalar_t_in, true>;
output->DATA_PTR<scalar_t_0>(), HostApplyLayerNorm<scalar_t_in, accscalar_t, scalar_t_out>(
mean->DATA_PTR<accscalar_t>(), output->DATA_PTR<scalar_t_out>(),
invvar->DATA_PTR<accscalar_t>(), mean->DATA_PTR<accscalar_t>(),
input->DATA_PTR<scalar_t_0>(), invvar->DATA_PTR<accscalar_t>(),
n1,n2, input->DATA_PTR<scalar_t_in>(),
epsilon, n1,n2,
gamma != NULL ? gamma->DATA_PTR<scalar_t_0>() : NULL, epsilon,
beta != NULL ? beta->DATA_PTR<scalar_t_0>() : NULL); gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
beta != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL);
) )
} }
template<typename T, typename U> template<typename T, typename U=float, typename V=T>
void HostLayerNormGradient( void HostLayerNormGradient(
const T* dout, const V* dout,
const U* mean, const U* mean,
const U* invvar, const U* invvar,
at::Tensor* input, at::Tensor* input,
int n1, int n1,
int n2, int n2,
const T* gamma, const V* gamma,
const T* beta, const V* beta,
double epsilon, double epsilon,
T* grad_input, T* grad_input,
T* grad_gamma, V* grad_gamma,
T* grad_beta V* grad_beta
) )
{ {
auto stream = at::cuda::getCurrentCUDAStream().stream(); auto stream = at::cuda::getCurrentCUDAStream().stream();
...@@ -730,8 +745,13 @@ void HostLayerNormGradient( ...@@ -730,8 +745,13 @@ void HostLayerNormGradient(
const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1); const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
const int nshared2_b = threads2.x * threads2.y * sizeof(U); const int nshared2_b = threads2.x * threads2.y * sizeof(U);
const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b; const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype((input->scalar_type()==at::ScalarType::Half || // note (mkozuki): I can hard code part_grad_gamma's dtype as float given that
input->scalar_type()==at::ScalarType::BFloat16) ? at::ScalarType::Float : input->scalar_type())); // the `cuda_layer_norm_gradient` doesn't support double.
const auto part_grad_dtype =
(input->scalar_type() == at::ScalarType::Half || input->scalar_type() == at::ScalarType::BFloat16) ?
at::ScalarType::Float :
input->scalar_type();
at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype(part_grad_dtype));
at::Tensor part_grad_beta = at::empty_like(part_grad_gamma); at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>( cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
dout, dout,
...@@ -794,21 +814,23 @@ void cuda_layer_norm_gradient( ...@@ -794,21 +814,23 @@ void cuda_layer_norm_gradient(
at::Tensor* grad_beta) at::Tensor* grad_beta)
{ {
using namespace at; using namespace at;
DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(input->scalar_type(), 0, "cuComputeGradInput", // we can do away with `accscalar_t` as there're only three dtypes: fp32, fp16, bf16
using accscalar_t = at::acc_type<scalar_t_0, true>; DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
HostLayerNormGradient( input->scalar_type(), gamma == NULL ? input->scalar_type() : gamma->scalar_type(), "cuComputeGradInput",
dout->DATA_PTR<scalar_t_0>(), using accscalar_t = at::acc_type<scalar_t_in, true>;
mean->DATA_PTR<accscalar_t>(), HostLayerNormGradient(
invvar->DATA_PTR<accscalar_t>(), dout->DATA_PTR<scalar_t_out>(),
input, mean->DATA_PTR<accscalar_t>(),
n1,n2, invvar->DATA_PTR<accscalar_t>(),
input,
n1,n2,
// TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
// if gamma Tensor is NULL on input. // if gamma Tensor is NULL on input.
gamma != NULL ? gamma->DATA_PTR<scalar_t_0>() : NULL, gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
gamma != NULL ? beta->DATA_PTR<scalar_t_0>() : NULL, gamma != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL,
epsilon, epsilon,
grad_input->DATA_PTR<scalar_t_0>(), grad_input->DATA_PTR<scalar_t_in>(),
gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_0>() : NULL, gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL,
gamma != NULL ? grad_beta->DATA_PTR<scalar_t_0>() : NULL); gamma != NULL ? grad_beta->DATA_PTR<scalar_t_out>() : NULL);
) )
} }
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {
torch::Tensor fwd_cuda(
torch::Tensor const& input,
torch::Tensor const& mask,
float scale_factor);
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor);
int get_batch_per_block_cuda(
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads);
torch::Tensor fwd(
torch::Tensor const& input,
torch::Tensor const& mask,
float scale_factor) {
AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");
return fwd_cuda(input, mask, scale_factor);
}
torch::Tensor bwd(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor) {
AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor");
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return bwd_cuda(output_grads, softmax_results, scale_factor);
}
int get_batch_per_block(
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads) {
return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads);
}
} // end namespace scaled_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward",
&multihead_attn::fused_softmax::scaled_masked_softmax::fwd,
"Self Multihead Attention scaled, time masked softmax -- Forward.");
m.def("backward",
&multihead_attn::fused_softmax::scaled_masked_softmax::bwd,
"Self Multihead Attention scaled, time masked softmax -- Backward.");
m.def("get_batch_per_block",
&multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block,
"Return Batch per block size."
);
}
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <assert.h>
#include <cuda_fp16.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <cuda_fp16.h>
#include <c10/macros/Macros.h>
namespace {
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }
template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }
template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }
template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }
template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); }
int log2_ceil(int value) {
int log2_value = 0;
while ((1 << log2_value) < value) ++log2_value;
return log2_value;
}
template<typename T>
struct Add {
__device__ __forceinline__ T operator()(T a, T b) const {
return a + b;
}
};
template<typename T>
struct Max {
__device__ __forceinline__ T operator()(T a, T b) const {
return a < b ? b : a;
}
};
template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if CUDA_VERSION >= 9000
return __shfl_xor_sync(mask, value, laneMask, width);
#else
return __shfl_xor(value, laneMask, width);
#endif
}
template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t* sum) {
ReduceOp<acc_t> r;
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
sum[i] = r(sum[i], b);
}
}
}
/*
* Extended softmax (from native aten pytorch) with following additional features
* 1) input scaling
* 2) Explicit masking
*/
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_masked_softmax_warp_forward(
output_t *dst,
const input_t *src,
const uint8_t *mask,
const acc_t scale,
int micro_batch_size,
int element_count,
int pad_batches)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_forward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
// gridDim/blockIdx = (seq_len, attn_heads, batches)
int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH;
int pad_first_batch = 0;
if (pad_batches != 1) { // bert style
pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH;
} else { // gpt2 style
pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
}
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
int local_batches = micro_batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
// load data from global memory
acc_t elements[WARP_BATCH][WARP_ITERATIONS];
input_t temp_data[ELEMENTS_PER_LDG_STG];
uint8_t temp_mask[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
int itr_idx = i*element_count+it*WARP_SIZE;
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(temp_mask, mask + itr_idx);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (temp_mask[element] != 1) {
elements[i][it + element] = (acc_t)temp_data[element] * scale;
} else {
elements[i][it + element] = -10000.0;
}
}
} else {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
}
}
}
}
// compute max_value
acc_t max_value[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
max_value[i] = elements[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
acc_t sum[WARP_BATCH] { 0.0f };
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
elements[i][it] = std::exp((elements[i][it] - max_value[i]));
sum[i] += elements[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// store result
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = elements[i][it + element] / sum[i];
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
} else {
break;
}
}
}
}
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_masked_softmax_warp_backward(
output_t *gradInput,
input_t *grad,
const input_t *output,
acc_t scale,
int micro_batch_size,
int element_count)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_backward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
// blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
// gridDim/blockIdx = (seq_len, attn_heads, batches)
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
int local_batches = micro_batch_size - first_batch;
if (local_batches > WARP_BATCH)
local_batches = WARP_BATCH;
// there might be multiple batches per warp. compute the index within the batch
int local_idx = threadIdx.x;
// the first element to process by the current thread
int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
grad += thread_offset;
output += thread_offset;
gradInput += thread_offset;
// load data from global memory
acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
input_t temp_grad[ELEMENTS_PER_LDG_STG];
input_t temp_output[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
int batch_element_count = (i >= local_batches) ? 0 : element_count;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < batch_element_count) {
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count + it * WARP_SIZE);
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count + it * WARP_SIZE);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
output_reg[i][it + element] = (acc_t)temp_output[element];
}
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
}
}
}
}
acc_t sum[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
sum[i] = grad_reg[i][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
sum[i] += grad_reg[i][it];
}
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
// store result
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
if (i >= local_batches)
break;
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) {
int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
if (element_index < element_count) {
// compute gradients
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count + it * WARP_SIZE, out);
}
}
}
}
} // end of anonymous namespace
int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){
int log2_elements = log2_ceil(key_seq_len);
const int next_power_of_two = 1 << log2_elements;
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
return batches_per_block;
}
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_forward(
output_t *dst,
const input_t *src,
const uint8_t *mask,
const input_t scale,
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads,
int pad_batches)
{
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048 );
if (key_seq_len == 0) {
return;
} else {
int log2_elements = log2_ceil(key_seq_len);
const int next_power_of_two = 1 << log2_elements;
int batch_count = batches * attn_heads * query_seq_len;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
// use 128 threads per block to maximimize gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0);
dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches);
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 1: // 2
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 2: // 4
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 3: // 8
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 4: // 16
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 5: // 32
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 6: // 64
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 7: // 128
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 8: // 256
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 9: // 512
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 10: // 1024
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
case 11: // 2048
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
break;
default:
break;
}
}
}
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_backward(
output_t *grad_input,
input_t *grad,
const input_t *output,
const acc_t scale,
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads)
{
TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 2048 );
if (key_seq_len == 0) {
return;
} else {
int log2_elements = log2_ceil(key_seq_len);
const int next_power_of_two = 1 << log2_elements;
int batch_count = batches * attn_heads * query_seq_len;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
// use 128 threads per block to maximimize gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = batch_count/batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 1: // 2
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 2: // 4
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 3: // 8
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 4: // 16
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 5: // 32
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 6: // 64
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 7: // 128
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 8: // 256
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 9: // 512
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 10: // 1024
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
case 11: // 2048
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
break;
default:
break;
}
}
}
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_masked_softmax.h"
#include "type_shim.h"
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {
int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){
return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
}
torch::Tensor fwd_cuda(
torch::Tensor const& input,
torch::Tensor const& mask,
float scale_factor)
{
// input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
const int batches = input.size(0);
const int pad_batches = mask.size(0);
const int attn_heads = input.size(1);
const int query_seq_len = input.size(2);
const int key_seq_len = input.size(3);
TORCH_INTERNAL_ASSERT(key_seq_len <= 2048);
TORCH_INTERNAL_ASSERT(query_seq_len > 1);
TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);
TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);
// Output
auto act_options = input.options().requires_grad(false);
torch::Tensor softmax_results =
torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);
// Softmax Intermediate Result Ptr
void* input_ptr = static_cast<void*>(input.data_ptr());
void* mask_ptr = static_cast<void*>(mask.data_ptr());
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
DISPATCH_HALF_AND_BFLOAT(
input.scalar_type(),
"dispatch_scaled_masked_softmax_forward",
dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(softmax_results_ptr),
reinterpret_cast<const scalar_t*>(input_ptr),
reinterpret_cast<const uint8_t*>(mask_ptr),
scale_factor,
query_seq_len,
key_seq_len,
batches,
attn_heads,
pad_batches);
);
return softmax_results;
}
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads_,
torch::Tensor const& softmax_results_,
float scale_factor) {
auto output_grads = output_grads_.contiguous();
auto softmax_results = softmax_results_.contiguous();
//output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
const int batches = output_grads.size(0);
const int attn_heads = output_grads.size(1);
const int query_seq_len = output_grads.size(2);
const int key_seq_len = output_grads.size(3);
void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
//Softmax Grad
DISPATCH_HALF_AND_BFLOAT(
output_grads_.scalar_type(),
"dispatch_scaled_masked_softmax_backward",
dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
scale_factor,
query_seq_len,
key_seq_len,
batches,
attn_heads);
);
//backward pass is completely in-place
return output_grads;
}
}
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment