OpenDAS / apex — Commit f79993d9

Authored Oct 15, 2021 by hubertlu-tw
Merge remote-tracking branch 'upstream/master' into IFU-master-2021-10-15
Parents: 297ab210, 1d5f7e55
Changes: 117 files in the commit; showing 20 changed files with 5041 additions and 152 deletions (+5041, -152).
Changed files shown on this page:

  apex/transformer/tensor_parallel/tests/__init__.py          +0     -0
  apex/transformer/tensor_parallel/tests/arguments.py         +766   -0
  apex/transformer/tensor_parallel/tests/commons.py           +87    -0
  apex/transformer/tensor_parallel/tests/global_vars.py       +260   -0
  apex/transformer/tensor_parallel/utils.py                   +64    -0
  csrc/amp_C_frontend.cpp                                     +9     -0
  csrc/fused_dense.cpp                                        +192   -0
  csrc/fused_dense_cuda.cu                                    +1437  -0
  csrc/layer_norm_cuda.cpp                                    +28    -4
  csrc/layer_norm_cuda_kernel.cu                              +111   -89
  csrc/megatron/scaled_masked_softmax.cpp                     +97    -0
  csrc/megatron/scaled_masked_softmax.h                       +505   -0
  csrc/megatron/scaled_masked_softmax_cuda.cu                 +117   -0
  csrc/megatron/scaled_upper_triang_masked_softmax.cpp        +72    -0
  csrc/megatron/scaled_upper_triang_masked_softmax.h          +513   -0
  csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu    +98    -0
  csrc/mlp.cpp                                                +10    -16
  csrc/mlp_cuda.cu                                            +344   -43
  csrc/multi_tensor_l2norm_kernel.cu                          +5     -0
  csrc/multi_tensor_l2norm_scale_kernel.cu                    +326   -0
apex/transformer/tensor_parallel/tests/__init__.py (new file, mode 100644; empty)
apex/transformer/tensor_parallel/tests/arguments.py (new file, mode 100644)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron arguments."""
import argparse
import os

import torch


def parse_args(extra_args_provider=None, defaults={},
               ignore_unknown_args=False):
    """Parse all arguments."""
    parser = argparse.ArgumentParser(description='Megatron-LM Arguments',
                                     allow_abbrev=False)

    # Standard arguments.
    parser = _add_network_size_args(parser)
    parser = _add_regularization_args(parser)
    parser = _add_training_args(parser)
    parser = _add_initialization_args(parser)
    parser = _add_learning_rate_args(parser)
    parser = _add_checkpointing_args(parser)
    parser = _add_mixed_precision_args(parser)
    parser = _add_distributed_args(parser)
    parser = _add_validation_args(parser)
    parser = _add_data_args(parser)
    parser = _add_autoresume_args(parser)
    parser = _add_biencoder_args(parser)
    parser = _add_vit_args(parser)
    parser = _add_logging_args(parser)

    # Custom arguments.
    if extra_args_provider is not None:
        parser = extra_args_provider(parser)

    # Parse.
    if ignore_unknown_args:
        args, _ = parser.parse_known_args()
    else:
        args = parser.parse_args()

    # Distributed args.
    args.rank = int(os.getenv('RANK', '0'))
    args.world_size = int(os.getenv("WORLD_SIZE", '1'))
    # Tensor model parallel size.
    args.tensor_model_parallel_size = min(
        args.tensor_model_parallel_size, args.world_size)
    assert args.world_size % args.tensor_model_parallel_size == 0, 'world size' \
        ' ({}) is not divisible by tensor model parallel size ({})'.format(
            args.world_size, args.tensor_model_parallel_size)
    # Pipeline model parallel size.
    args.pipeline_model_parallel_size = min(
        args.pipeline_model_parallel_size,
        (args.world_size // args.tensor_model_parallel_size))
    # Checks.
    model_parallel_size = args.pipeline_model_parallel_size * \
        args.tensor_model_parallel_size
    assert args.world_size % model_parallel_size == 0, 'world size is not' \
        ' divisible by tensor parallel size ({}) times pipeline parallel ' \
        'size ({})'.format(args.world_size, args.tensor_model_parallel_size,
                           args.pipeline_model_parallel_size)
    args.data_parallel_size = args.world_size // model_parallel_size
    if args.rank == 0:
        print('using world size: {}, data-parallel-size: {}, '
              'tensor-model-parallel size: {}, '
              'pipeline-model-parallel size: {} '.format(
                  args.world_size, args.data_parallel_size,
                  args.tensor_model_parallel_size,
                  args.pipeline_model_parallel_size), flush=True)

    # Deprecated arguments
    assert args.batch_size is None, '--batch-size argument is no longer ' \
        'valid, use --micro-batch-size instead'
    del args.batch_size
    assert args.warmup is None, '--warmup argument is no longer valid, use ' \
        '--lr-warmup-fraction instead'
    del args.warmup
    assert args.model_parallel_size is None, '--model-parallel-size is no ' \
        'longer valid, use --tensor-model-parallel-size instead'
    del args.model_parallel_size

    # Set input defaults.
    for key in defaults:
        # For default to be valid, it should not be provided in the
        # arguments that are passed to the program. We check this by
        # ensuring the arg is set to None.
        if getattr(args, key) is not None:
            if args.rank == 0:
                print('WARNING: overriding default arguments for {key}:{v} \
                       with {key}:{v2}'.format(key=key, v=defaults[key],
                                               v2=getattr(args, key)),
                      flush=True)
        else:
            setattr(args, key, defaults[key])

    # Batch size.
    assert args.micro_batch_size is not None
    assert args.micro_batch_size > 0
    if args.global_batch_size is None:
        args.global_batch_size = args.micro_batch_size * args.data_parallel_size
        if args.rank == 0:
            print('setting global batch size to {}'.format(
                args.global_batch_size), flush=True)
    assert args.global_batch_size > 0
    if args.num_layers_per_virtual_pipeline_stage is not None:
        assert args.pipeline_model_parallel_size > 2, \
            'pipeline-model-parallel size should be greater than 2 with ' \
            'interleaved schedule'
        assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, \
            'number of layers is not divisible by number of layers per virtual ' \
            'pipeline stage'
        args.virtual_pipeline_model_parallel_size = \
            (args.num_layers // args.pipeline_model_parallel_size) // \
            args.num_layers_per_virtual_pipeline_stage
    else:
        args.virtual_pipeline_model_parallel_size = None

    # Parameters dtype.
    args.params_dtype = torch.float
    if args.fp16:
        assert not args.bf16
        args.params_dtype = torch.half
    if args.bf16:
        assert not args.fp16
        args.params_dtype = torch.bfloat16
        # bfloat16 requires gradient accumulation and all-reduce to
        # be done in fp32.
        if not args.accumulate_allreduce_grads_in_fp32:
            args.accumulate_allreduce_grads_in_fp32 = True
            if args.rank == 0:
                print('accumulate and all-reduce gradients in fp32 for '
                      'bfloat16 data type.', flush=True)

    if args.rank == 0:
        print('using {} for parameters ...'.format(args.params_dtype),
              flush=True)

    # If we do accumulation and all-reduces in fp32, we need to have
    # local DDP and we should set the use-contiguous-buffers-in-ddp.
    if args.accumulate_allreduce_grads_in_fp32:
        assert args.DDP_impl == 'local'
        args.use_contiguous_buffers_in_ddp = True

    # If we use a contiguous buffer to hold main grads, we need to have
    # local DDP.
    if args.use_contiguous_buffers_in_ddp:
        assert args.DDP_impl == 'local'

    if args.dataloader_type is None:
        args.dataloader_type = 'single'

    # Consumed tokens.
    args.consumed_train_samples = 0
    args.consumed_valid_samples = 0

    # Iteration-based training.
    if args.train_iters:
        # If we use iteration-based training, make sure the
        # sample-based options are off.
        assert args.train_samples is None, \
            'expected iteration-based training'
        assert args.lr_decay_samples is None, \
            'expected iteration-based learning rate decay'
        assert args.lr_warmup_samples == 0, \
            'expected iteration-based learning rate warmup'
        assert args.rampup_batch_size is None, \
            'expected no batch-size rampup for iteration-based training'
        if args.lr_warmup_fraction is not None:
            assert args.lr_warmup_iters == 0, \
                'can only specify one of lr-warmup-fraction and lr-warmup-iters'

    # Sample-based training.
    if args.train_samples:
        # If we use sample-based training, make sure the
        # iteration-based options are off.
        assert args.train_iters is None, \
            'expected sample-based training'
        assert args.lr_decay_iters is None, \
            'expected sample-based learning rate decay'
        assert args.lr_warmup_iters == 0, \
            'expected sample-based learning rate warmup'
        if args.lr_warmup_fraction is not None:
            assert args.lr_warmup_samples == 0, \
                'can only specify one of lr-warmup-fraction ' \
                'and lr-warmup-samples'

    # Check required arguments.
    required_args = ['num_layers', 'hidden_size', 'num_attention_heads',
                     'max_position_embeddings']
    for req_arg in required_args:
        _check_arg_is_not_none(args, req_arg)

    # Checks.
    if args.ffn_hidden_size is None:
        args.ffn_hidden_size = 4 * args.hidden_size
    if args.kv_channels is None:
        assert args.hidden_size % args.num_attention_heads == 0
        args.kv_channels = args.hidden_size // args.num_attention_heads
    if args.seq_length is not None:
        assert args.encoder_seq_length is None
        args.encoder_seq_length = args.seq_length
    else:
        assert args.encoder_seq_length is not None
        args.seq_length = args.encoder_seq_length
    if args.seq_length is not None:
        assert args.max_position_embeddings >= args.seq_length
    if args.decoder_seq_length is not None:
        assert args.max_position_embeddings >= args.decoder_seq_length
    if args.lr is not None:
        assert args.min_lr <= args.lr
    if args.save is not None:
        assert args.save_interval is not None
    # Mixed precision checks.
    if args.fp16_lm_cross_entropy:
        assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
    if args.fp32_residual_connection:
        assert args.fp16 or args.bf16, \
            'residual connection in fp32 only supported when using fp16 or bf16.'
    # Activation checkpointing.
    if args.distribute_checkpointed_activations:
        assert args.checkpoint_activations, \
            'for distribute-checkpointed-activations to work you ' \
            'need to enable checkpoint-activations'

    _print_args(args)
    return args


def _print_args(args):
    """Print arguments."""
    if args.rank == 0:
        print('------------------------ arguments ------------------------',
              flush=True)
        str_list = []
        for arg in vars(args):
            dots = '.' * (48 - len(arg))
            str_list.append('  {} {} {}'.format(arg, dots, getattr(args, arg)))
        for arg in sorted(str_list, key=lambda x: x.lower()):
            print(arg, flush=True)
        print('-------------------- end of arguments ---------------------',
              flush=True)


def _check_arg_is_not_none(args, arg):
    assert getattr(args, arg) is not None, '{} argument is None'.format(arg)


def _add_network_size_args(parser):
    group = parser.add_argument_group(title='network size')

    group.add_argument('--num-layers', type=int, default=None,
                       help='Number of transformer layers.')
    group.add_argument('--hidden-size', type=int, default=None,
                       help='Transformer hidden size.')
    group.add_argument('--ffn-hidden-size', type=int, default=None,
                       help='Transformer Feed-Forward Network hidden size. '
                       'This is set to 4*hidden-size if not provided')
    group.add_argument('--num-attention-heads', type=int, default=None,
                       help='Number of transformer attention heads.')
    group.add_argument('--kv-channels', type=int, default=None,
                       help='Projection weights dimension in multi-head '
                       'attention. This is set to '
                       '   args.hidden_size // args.num_attention_heads '
                       'if not provided.')
    group.add_argument('--max-position-embeddings', type=int, default=None,
                       help='Maximum number of position embeddings to use. '
                       'This is the size of position embedding.')
    group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
                       help='Pad the vocab size to be divisible by this value. '
                       'This is added for computational efficiency reasons.')
    group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
                       help='Layer norm epsilon.')
    group.add_argument('--apply-residual-connection-post-layernorm',
                       action='store_true',
                       help='If set, use original BERT residual connection '
                       'ordering.')
    group.add_argument('--openai-gelu', action='store_true',
                       help='Use OpenAIs GeLU implementation. This option '
                       'should not be used unless for backward compatibility '
                       'reasons.')
    group.add_argument('--onnx-safe', type=bool, required=False,
                       help='Use workarounds for known problems with '
                       'Torch ONNX exporter')
    group.add_argument('--bert-no-binary-head', action='store_false',
                       help='Disable BERT binary head.',
                       dest='bert_binary_head')

    return parser


def _add_logging_args(parser):
    group = parser.add_argument_group(title='logging')

    group.add_argument('--log-params-norm', action='store_true',
                       help='If set, calculate and log parameters norm.')
    group.add_argument('--log-num-zeros-in-grad', action='store_true',
                       help='If set, calculate and log the number of zeros in gradient.')
    group.add_argument('--tensorboard-log-interval', type=int, default=1,
                       help='Report to tensorboard interval.')
    group.add_argument('--tensorboard-queue-size', type=int, default=1000,
                       help='Size of the tensorboard queue for pending events '
                       'and summaries before one of the ‘add’ calls forces a '
                       'flush to disk.')
    group.add_argument('--log-timers-to-tensorboard', action='store_true',
                       help='If set, write timers to tensorboard.')
    group.add_argument('--log-batch-size-to-tensorboard', action='store_true',
                       help='If set, write batch-size to tensorboard.')
    group.add_argument('--no-log-learnig-rate-to-tensorboard',
                       action='store_false',
                       help='Disable learning rate logging to tensorboard.',
                       dest='log_learning_rate_to_tensorboard')
    group.add_argument('--no-log-loss-scale-to-tensorboard',
                       action='store_false',
                       help='Disable loss-scale logging to tensorboard.',
                       dest='log_loss_scale_to_tensorboard')
    group.add_argument('--log-validation-ppl-to-tensorboard',
                       action='store_true',
                       help='If set, write validation perplexity to '
                       'tensorboard.')
    group.add_argument('--log-memory-to-tensorboard', action='store_true',
                       help='Enable memory logging to tensorboard.')

    return parser


def _add_regularization_args(parser):
    group = parser.add_argument_group(title='regularization')

    group.add_argument('--attention-dropout', type=float, default=0.1,
                       help='Post attention dropout probability.')
    group.add_argument('--hidden-dropout', type=float, default=0.1,
                       help='Dropout probability for hidden state transformer.')
    group.add_argument('--weight-decay', type=float, default=0.01,
                       help='Weight decay coefficient for L2 regularization.')
    group.add_argument('--clip-grad', type=float, default=1.0,
                       help='Gradient clipping based on global L2 norm.')
    group.add_argument('--adam-beta1', type=float, default=0.9,
                       help='First coefficient for computing running averages '
                       'of gradient and its square')
    group.add_argument('--adam-beta2', type=float, default=0.999,
                       help='Second coefficient for computing running averages '
                       'of gradient and its square')
    group.add_argument('--adam-eps', type=float, default=1e-08,
                       help='Term added to the denominator to improve '
                       'numerical stability')
    group.add_argument('--sgd-momentum', type=float, default=0.9,
                       help='Momentum factor for sgd')

    return parser


def _add_training_args(parser):
    group = parser.add_argument_group(title='training')

    group.add_argument('--micro-batch-size', type=int, default=None,
                       help='Batch size per model instance (local batch size). '
                       'Global batch size is local batch size times data '
                       'parallel size times number of micro batches.')
    group.add_argument('--batch-size', type=int, default=None,
                       help='Old batch size parameter, do not use. '
                       'Use --micro-batch-size instead')
    group.add_argument('--global-batch-size', type=int, default=None,
                       help='Training batch size. If set, it should be a '
                       'multiple of micro-batch-size times data-parallel-size. '
                       'If this value is None, then '
                       'use micro-batch-size * data-parallel-size as the '
                       'global batch size. This choice will result in 1 for '
                       'number of micro-batches.')
    group.add_argument('--rampup-batch-size', nargs='*', default=None,
                       help='Batch size ramp up with the following values:'
                       '  --rampup-batch-size <start batch size> '
                       '                      <batch size increment> '
                       '                      <ramp-up samples> '
                       'For example:'
                       '   --rampup-batch-size 16 8 300000 \ '
                       '   --global-batch-size 1024 '
                       'will start with global batch size 16 and over '
                       ' (1024 - 16) / 8 = 126 intervals will increase '
                       'the batch size linearly to 1024. In each interval '
                       'we will use approximately 300000 / 126 = 2380 samples.')
    group.add_argument('--checkpoint-activations', action='store_true',
                       help='Checkpoint activation to allow for training '
                       'with larger models, sequences, and batch sizes.')
    group.add_argument('--distribute-checkpointed-activations',
                       action='store_true',
                       help='If set, distribute checkpointed activations '
                       'across model parallel group.')
    group.add_argument('--checkpoint-num-layers', type=int, default=1,
                       help='chunk size (number of layers) for checkpointing.')
    group.add_argument('--train-iters', type=int, default=None,
                       help='Total number of iterations to train over all '
                       'training runs. Note that either train-iters or '
                       'train-samples should be provided.')
    group.add_argument('--train-samples', type=int, default=None,
                       help='Total number of samples to train over all '
                       'training runs. Note that either train-iters or '
                       'train-samples should be provided.')
    group.add_argument('--log-interval', type=int, default=100,
                       help='Report loss and timing interval.')
    group.add_argument('--exit-interval', type=int, default=None,
                       help='Exit the program after the iteration is divisible '
                       'by this value.')
    group.add_argument('--exit-duration-in-mins', type=int, default=None,
                       help='Exit the program after this many minutes.')
    group.add_argument('--tensorboard-dir', type=str, default=None,
                       help='Write TensorBoard logs to this directory.')
    group.add_argument('--no-masked-softmax-fusion',
                       action='store_false',
                       help='Disable fusion of query_key_value scaling, '
                       'masking, and softmax.',
                       dest='masked_softmax_fusion')
    group.add_argument('--no-bias-gelu-fusion', action='store_false',
                       help='Disable bias and gelu fusion.',
                       dest='bias_gelu_fusion')
    group.add_argument('--no-bias-dropout-fusion', action='store_false',
                       help='Disable bias and dropout fusion.',
                       dest='bias_dropout_fusion')
    group.add_argument('--optimizer', type=str, default='adam',
                       choices=['adam', 'sgd'],
                       help='Optimizer function')
    group.add_argument('--dataloader-type', type=str, default=None,
                       choices=['single', 'cyclic'],
                       help='Single pass vs multiple pass data loader')

    return parser


def _add_initialization_args(parser):
    group = parser.add_argument_group(title='initialization')

    group.add_argument('--seed', type=int, default=1234,
                       help='Random seed used for python, numpy, '
                       'pytorch, and cuda.')
    group.add_argument('--init-method-std', type=float, default=0.02,
                       help='Standard deviation of the zero mean normal '
                       'distribution used for weight initialization.')
    group.add_argument('--init-method-xavier-uniform', action='store_true',
                       help='Enable Xavier uniform parameter initialization')

    return parser


def _add_learning_rate_args(parser):
    group = parser.add_argument_group(title='learning rate')

    group.add_argument('--lr', type=float, default=None,
                       help='Initial learning rate. Depending on decay style '
                       'and initial warmup, the learning rate at each '
                       'iteration would be different.')
    group.add_argument('--lr-decay-style', type=str, default='linear',
                       choices=['constant', 'linear', 'cosine'],
                       help='Learning rate decay function.')
    group.add_argument('--lr-decay-iters', type=int, default=None,
                       help='number of iterations to decay learning rate over,'
                       ' If None defaults to `--train-iters`')
    group.add_argument('--lr-decay-samples', type=int, default=None,
                       help='number of samples to decay learning rate over,'
                       ' If None defaults to `--train-samples`')
    group.add_argument('--lr-warmup-fraction', type=float, default=None,
                       help='fraction of lr-warmup-(iters/samples) to use '
                       'for warmup (as a float)')
    group.add_argument('--lr-warmup-iters', type=int, default=0,
                       help='number of iterations to linearly warmup '
                       'learning rate over.')
    group.add_argument('--lr-warmup-samples', type=int, default=0,
                       help='number of samples to linearly warmup '
                       'learning rate over.')
    group.add_argument('--warmup', type=int, default=None,
                       help='Old lr warmup argument, do not use. Use one of the '
                       '--lr-warmup-* arguments above')
    group.add_argument('--min-lr', type=float, default=0.0,
                       help='Minimum value for learning rate. The scheduler '
                       'clips values below this threshold.')
    group.add_argument('--override-lr-scheduler', action='store_true',
                       help='Reset the values of the scheduler (learning rate, '
                       'warmup iterations, minimum learning rate, maximum '
                       'number of iterations, and decay style) from input '
                       'arguments and ignore values from checkpoints. Note '
                       'that all the above values will be reset.')
    group.add_argument('--use-checkpoint-lr-scheduler', action='store_true',
                       help='Use checkpoint to set the values of the scheduler '
                       '(learning rate, warmup iterations, minimum learning '
                       'rate, maximum number of iterations, and decay style) '
                       'from checkpoint and ignore input arguments.')

    return parser


def _add_checkpointing_args(parser):
    group = parser.add_argument_group(title='checkpointing')

    group.add_argument('--save', type=str, default=None,
                       help='Output directory to save checkpoints to.')
    group.add_argument('--save-interval', type=int, default=None,
                       help='Number of iterations between checkpoint saves.')
    group.add_argument('--no-save-optim', action='store_true', default=None,
                       help='Do not save current optimizer.')
    group.add_argument('--no-save-rng', action='store_true', default=None,
                       help='Do not save current rng state.')
    group.add_argument('--load', type=str, default=None,
                       help='Directory containing a model checkpoint.')
    group.add_argument('--no-load-optim', action='store_true', default=None,
                       help='Do not load optimizer when loading checkpoint.')
    group.add_argument('--no-load-rng', action='store_true', default=None,
                       help='Do not load rng state when loading checkpoint.')
    group.add_argument('--finetune', action='store_true',
                       help='Load model for finetuning. Do not load optimizer '
                       'or rng state from checkpoint and set iteration to 0. '
                       'Assumed when loading a release checkpoint.')

    return parser


def _add_mixed_precision_args(parser):
    group = parser.add_argument_group(title='mixed precision')

    group.add_argument('--fp16', action='store_true',
                       help='Run model in fp16 mode.')
    group.add_argument('--bf16', action='store_true',
                       help='Run model in bfloat16 mode.')
    group.add_argument('--loss-scale', type=float, default=None,
                       help='Static loss scaling, positive power of 2 '
                       'values can improve fp16 convergence. If None, dynamic '
                       'loss scaling is used.')
    group.add_argument('--initial-loss-scale', type=float, default=2**32,
                       help='Initial loss-scale for dynamic loss scaling.')
    group.add_argument('--min-loss-scale', type=float, default=1.0,
                       help='Minimum loss scale for dynamic loss scale.')
    group.add_argument('--loss-scale-window', type=float, default=1000,
                       help='Window over which to raise/lower dynamic scale.')
    group.add_argument('--hysteresis', type=int, default=2,
                       help='hysteresis for dynamic loss scaling')
    group.add_argument('--fp32-residual-connection', action='store_true',
                       help='Move residual connections to fp32.')
    group.add_argument('--no-query-key-layer-scaling', action='store_false',
                       help='Do not scale Q * K^T by 1 / layer-number.',
                       dest='apply_query_key_layer_scaling')
    group.add_argument('--attention-softmax-in-fp32', action='store_true',
                       help='Run attention masking and softmax in fp32. '
                       'This flag is ignored unless '
                       '--no-query-key-layer-scaling is specified.')
    group.add_argument('--accumulate-allreduce-grads-in-fp32',
                       action='store_true',
                       help='Gradient accumulation and all-reduce in fp32.')
    group.add_argument('--fp16-lm-cross-entropy', action='store_true',
                       help='Move the cross entropy unreduced loss calculation '
                       'for lm head to fp16.')

    return parser


def _add_distributed_args(parser):
    group = parser.add_argument_group(title='distributed')

    group.add_argument('--tensor-model-parallel-size', type=int, default=1,
                       help='Degree of tensor model parallelism.')
    group.add_argument('--pipeline-model-parallel-size', type=int, default=1,
                       help='Degree of pipeline model parallelism.')
    group.add_argument('--model-parallel-size', type=int, default=None,
                       help='Old model parallel argument, do not use. Use '
                       '--tensor-model-parallel-size instead.')
    group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None,
                       help='Number of layers per virtual pipeline stage')
    group.add_argument('--distributed-backend', default='nccl',
                       choices=['nccl', 'gloo'],
                       help='Which backend to use for distributed training.')
    group.add_argument('--DDP-impl', default='local',
                       choices=['local', 'torch'],
                       help='which DistributedDataParallel implementation '
                       'to use.')
    group.add_argument('--use-contiguous-buffers-in-ddp', action='store_true',
                       help='If set, use contiguous buffer in DDP. Note that '
                       'this option only works with local DDP.')
    group.add_argument('--no-scatter-gather-tensors-in-pipeline',
                       action='store_false',
                       help='Use scatter/gather to optimize communication of tensors in pipeline',
                       dest='scatter_gather_tensors_in_pipeline')
    group.add_argument('--local_rank', type=int, default=None,
                       help='local rank passed from distributed launcher.')
    group.add_argument('--lazy-mpu-init', type=bool, required=False,
                       help='If set to True, initialize_megatron() '
                       'skips DDP initialization and returns function to '
                       'complete it instead. Also turns on '
                       '--use-cpu-initialization flag. This is for '
                       'external DDP manager.')
    group.add_argument('--use-cpu-initialization', action='store_true',
                       default=None,
                       help='If set, affine parallel weights '
                       'initialization uses CPU')
    group.add_argument('--empty-unused-memory-level', default=0, type=int,
                       choices=[0, 1, 2],
                       help='Call torch.cuda.empty_cache() each iteration '
                       '(training and eval), to reduce fragmentation. '
                       '0=off, 1=moderate, 2=aggressive.')

    return parser


def _add_validation_args(parser):
    group = parser.add_argument_group(title='validation')

    group.add_argument('--eval-iters', type=int, default=100,
                       help='Number of iterations to run for evaluation '
                       'validation/test for.')
    group.add_argument('--eval-interval', type=int, default=1000,
                       help='Interval between running evaluation on '
                       'validation set.')

    return parser


def _add_data_args(parser):
    group = parser.add_argument_group(title='data and dataloader')

    group.add_argument('--data-path', nargs='*', default=None,
                       help='Path to the training dataset. Accepted format: '
                       '1) a single data path, 2) multiple datasets in the '
                       'form: dataset1-weight dataset1-path dataset2-weight '
                       'dataset2-path ...')
    group.add_argument('--split', type=str, default='969, 30, 1',
                       help='Comma-separated list of proportions for training,'
                       ' validation, and test split. For example the split '
                       '`90,5,5` will use 90%% of data for training, 5%% for '
                       'validation and 5%% for test.')
    group.add_argument('--vocab-file', type=str, default=None,
                       help='Path to the vocab file.')
    group.add_argument('--merge-file', type=str, default=None,
                       help='Path to the BPE merge file.')
    group.add_argument('--vocab-extra-ids', type=int, default=0,
                       help='Number of additional vocabulary tokens. '
                       'They are used for span masking in the T5 model')
    group.add_argument('--seq-length', type=int, default=None,
                       help='Maximum sequence length to process.')
    group.add_argument('--encoder-seq-length', type=int, default=None,
                       help='Maximum encoder sequence length to process. '
                       'This should be exclusive of --seq-length')
    group.add_argument('--decoder-seq-length', type=int, default=None,
                       help="Maximum decoder sequence length to process.")
    group.add_argument('--retriever-seq-length', type=int, default=256,
                       help='Maximum sequence length for the biencoder model '
                       'for retriever')
    group.add_argument('--sample-rate', type=float, default=1.0,
                       help='sample rate for training data. Supposed to be 0 '
                       ' < sample_rate < 1')
    group.add_argument('--mask-prob', type=float, default=0.15,
                       help='Probability of replacing a token with mask.')
    group.add_argument('--short-seq-prob', type=float, default=0.1,
                       help='Probability of producing a short sequence.')
    group.add_argument('--mmap-warmup', action='store_true',
                       help='Warm up mmap files.')
    group.add_argument('--num-workers', type=int, default=2,
                       help="Dataloader number of workers.")
    group.add_argument('--tokenizer-type', type=str, default=None,
                       choices=['BertWordPieceLowerCase',
                                'BertWordPieceCase',
                                'GPT2BPETokenizer'],
                       help='What type of tokenizer to use.')
    group.add_argument('--data-impl', type=str, default='infer',
                       choices=['lazy', 'cached', 'mmap', 'infer'],
                       help='Implementation of indexed datasets.')
    group.add_argument('--reset-position-ids', action='store_true',
                       help='Reset position ids after end-of-document token.')
    group.add_argument('--reset-attention-mask', action='store_true',
                       help='Reset self attention mask after '
                       'end-of-document token.')
    group.add_argument('--eod-mask-loss', action='store_true',
                       help='Mask loss for the end of document tokens.')

    return parser


def _add_autoresume_args(parser):
    group = parser.add_argument_group(title='autoresume')

    group.add_argument('--adlr-autoresume', action='store_true',
                       help='Enable autoresume on adlr cluster.')
    group.add_argument('--adlr-autoresume-interval', type=int, default=1000,
                       help='Intervals over which check for autoresume '
                       'termination signal')

    return parser


def _add_biencoder_args(parser):
    group = parser.add_argument_group(title='biencoder')

    # network size
    group.add_argument('--ict-head-size', type=int, default=None,
                       help='Size of block embeddings to be used in ICT and '
                       'REALM (paper default: 128)')
    group.add_argument('--biencoder-projection-dim', type=int, default=0,
                       help='Size of projection head used in biencoder (paper '
                       'default: 128)')
    group.add_argument('--biencoder-shared-query-context-model',
                       action='store_true',
                       help='Whether to share the parameters of the query '
                       'and context models or not')

    # checkpointing
    group.add_argument('--ict-load', type=str, default=None,
                       help='Directory containing an ICTBertModel checkpoint')
    group.add_argument('--bert-load', type=str, default=None,
                       help='Directory containing a BertModel checkpoint '
                       '(needed to start ICT and REALM)')

    # data
    group.add_argument('--titles-data-path', type=str, default=None,
                       help='Path to titles dataset used for ICT')
    group.add_argument('--query-in-block-prob', type=float, default=0.1,
                       help='Probability of keeping query in block for '
                       'ICT dataset')
    group.add_argument('--use-one-sent-docs', action='store_true',
                       help='Whether to use one sentence documents in ICT')
    group.add_argument('--evidence-data-path', type=str, default=None,
                       help='Path to Wikipedia Evidence from DPR paper')

    # training
    group.add_argument('--retriever-report-topk-accuracies', nargs='+',
                       type=int, default=[],
                       help="Which top-k accuracies to report "
                       "(e.g. '1 5 20')")
    group.add_argument('--retriever-score-scaling', action='store_true',
                       help='Whether to scale retriever scores by inverse '
                       'square root of hidden size')

    # faiss index
    group.add_argument('--block-data-path', type=str, default=None,
                       help='Where to save/load BlockData to/from')
    group.add_argument('--embedding-path', type=str, default=None,
                       help='Where to save/load Open-Retrieval Embedding '
                       'data to/from')

    # indexer
    group.add_argument('--indexer-batch-size', type=int, default=128,
                       help='How large of batches to use when doing indexing '
                       'jobs')
    group.add_argument('--indexer-log-interval', type=int, default=1000,
                       help='After how many batches should the indexer '
                       'report progress')

    return parser


def _add_vit_args(parser):
    group = parser.add_argument_group(title="vit")

    group.add_argument('--num-classes', type=int, default=1000,
                       help='num of classes in vision classification task')
    group.add_argument('--img-dim', type=int, default=224,
                       help='Image size for vision classification task')
    group.add_argument('--num-channels', type=int, default=3,
                       help='Number of channels in input image data')
    group.add_argument('--patch-dim', type=int, default=16,
                       help='patch dimension used in vit')

    return parser
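For orientation, a minimal sketch of how this parser might be driven from a test script. The flag values, the defaults override, and the extra argument group below are illustrative, not taken from the commit; only the parse_args signature and the required flags come from the file above.

    import sys
    from apex.transformer.tensor_parallel.tests.arguments import parse_args

    def _extra_args(parser):
        # Hypothetical caller-supplied group, as allowed by extra_args_provider.
        group = parser.add_argument_group(title='example')
        group.add_argument('--example-flag', action='store_true')
        return parser

    # Smallest flag set that satisfies the required/derived-argument checks.
    sys.argv = ['test.py',
                '--micro-batch-size', '2',
                '--num-layers', '2',
                '--hidden-size', '64',
                '--num-attention-heads', '4',
                '--seq-length', '128',
                '--max-position-embeddings', '128']
    args = parse_args(extra_args_provider=_extra_args,
                      defaults={'tokenizer_type': 'GPT2BPETokenizer'},
                      ignore_unknown_args=True)
    print(args.global_batch_size)   # micro-batch-size * data-parallel-size = 2 here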
apex/transformer/tensor_parallel/tests/commons.py (new file, mode 100644)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random

import numpy
import torch

from apex import transformer
from apex.transformer.tensor_parallel.tests import global_vars


TEST_SUCCESS_MESSAGE = ">> passed the test :-)"


class IdentityLayer(torch.nn.Module):
    def __init__(self, size, scale=1.0):
        super(IdentityLayer, self).__init__()
        self.weight = torch.nn.Parameter(scale * torch.randn(size))

    def forward(self):
        return self.weight


def set_random_seed(seed):
    """Set random seed for reproducibility."""
    random.seed(seed)
    numpy.random.seed(seed)
    torch.manual_seed(seed)
    transformer.tensor_parallel.model_parallel_cuda_manual_seed(seed)


def initialize_distributed(backend='nccl'):
    """Initialize torch.distributed."""
    # Get local rank in case it is provided.
    # parser = argparse.ArgumentParser()
    # parser.add_argument('--local_rank', type=int, default=None,
    #                     help='local rank passed from distributed launcher')
    # args = parser.parse_args()
    args = global_vars.get_args()
    local_rank = args.local_rank

    # Get rank and world size.
    rank = int(os.getenv('RANK', '0'))
    world_size = int(os.getenv("WORLD_SIZE", '1'))

    print('> initializing torch.distributed with local rank: {}, '
          'rank: {}, world size: {}'.format(local_rank, rank, world_size))

    # Set the device id.
    device = rank % torch.cuda.device_count()
    if local_rank is not None:
        device = local_rank
    torch.cuda.set_device(device)

    # Call the init process.
    init_method = 'tcp://'
    master_ip = os.getenv('MASTER_ADDR', 'localhost')
    master_port = os.getenv('MASTER_PORT', '6000')
    init_method += master_ip + ':' + master_port
    torch.distributed.init_process_group(
        backend=backend,
        world_size=world_size,
        rank=rank,
        init_method=init_method)


def print_separator(message):
    torch.distributed.barrier()
    filler_len = (78 - len(message)) // 2
    filler = '-' * filler_len
    string = '\n' + filler + ' {} '.format(message) + filler
    if torch.distributed.get_rank() == 0:
        print(string, flush=True)
    torch.distributed.barrier()
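For orientation, a minimal single-process sketch of how these helpers are typically wired together. This is illustrative only: set_global_variables() parses Megatron-style flags from sys.argv (values below are made up), and a CUDA device plus the default MASTER_ADDR/MASTER_PORT fallbacks are assumed.

    import sys
    from apex.transformer.tensor_parallel.tests import global_vars
    from apex.transformer.tensor_parallel.tests.commons import (
        TEST_SUCCESS_MESSAGE, initialize_distributed, print_separator, set_random_seed)

    # Hypothetical minimal flag set accepted by parse_args().
    sys.argv = ['test.py', '--micro-batch-size', '2', '--num-layers', '2',
                '--hidden-size', '64', '--num-attention-heads', '4',
                '--seq-length', '128', '--max-position-embeddings', '128']

    global_vars.set_global_variables()   # args, microbatch calculator, timers, ...
    initialize_distributed()             # RANK=0, WORLD_SIZE=1 unless set in the env
    set_random_seed(1234)                # python / numpy / torch / model-parallel RNGs
    print_separator('smoke test')
    print(TEST_SUCCESS_MESSAGE)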
apex/transformer/tensor_parallel/tests/global_vars.py (new file, mode 100644)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron global variables."""
import os
import sys
import time

import torch

from apex.transformer.tensor_parallel.microbatches import build_num_microbatches_calculator
from apex.transformer.tensor_parallel.tests.arguments import parse_args


_GLOBAL_ARGS = None
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
_GLOBAL_TOKENIZER = None
_GLOBAL_TENSORBOARD_WRITER = None
_GLOBAL_ADLR_AUTORESUME = None
_GLOBAL_TIMERS = None


def get_args():
    """Return arguments."""
    _ensure_var_is_initialized(_GLOBAL_ARGS, 'args')
    return _GLOBAL_ARGS


def get_num_microbatches():
    return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()


def get_current_global_batch_size():
    return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_current_global_batch_size()


def update_num_microbatches(consumed_samples, consistency_check=True):
    _GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples,
                                               consistency_check)


# def get_tokenizer():
#     """Return tokenizer."""
#     _ensure_var_is_initialized(_GLOBAL_TOKENIZER, 'tokenizer')
#     return _GLOBAL_TOKENIZER


def get_tensorboard_writer():
    """Return tensorboard writer. It can be None so no need
    to check if it is initialized."""
    return _GLOBAL_TENSORBOARD_WRITER


def get_adlr_autoresume():
    """ADLR autoresume object. It can be None so no need
    to check if it is initialized."""
    return _GLOBAL_ADLR_AUTORESUME


def get_timers():
    """Return timers."""
    _ensure_var_is_initialized(_GLOBAL_TIMERS, 'timers')
    return _GLOBAL_TIMERS


def set_global_variables(extra_args_provider=None, args_defaults={},
                         ignore_unknown_args=False):
    """Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers."""
    args = _parse_args(extra_args_provider=extra_args_provider,
                       defaults=args_defaults,
                       ignore_unknown_args=ignore_unknown_args)
    _build_num_microbatches_calculator(args)
    # if args.vocab_file:
    #     _ = _build_tokenizer(args)
    _set_tensorboard_writer(args)
    _set_adlr_autoresume(args)
    _set_timers()


def _parse_args(extra_args_provider=None, defaults={},
                ignore_unknown_args=False):
    """Parse entire arguments."""
    global _GLOBAL_ARGS
    _ensure_var_is_not_initialized(_GLOBAL_ARGS, 'args')
    _GLOBAL_ARGS = parse_args(extra_args_provider=extra_args_provider,
                              defaults=defaults,
                              ignore_unknown_args=ignore_unknown_args)
    return _GLOBAL_ARGS


def _build_num_microbatches_calculator(args):
    global _GLOBAL_NUM_MICROBATCHES_CALCULATOR
    _ensure_var_is_not_initialized(_GLOBAL_NUM_MICROBATCHES_CALCULATOR,
                                   'num microbatches calculator')
    _GLOBAL_NUM_MICROBATCHES_CALCULATOR = build_num_microbatches_calculator(args)


# def _build_tokenizer(args):
#     """Initialize tokenizer."""
#     global _GLOBAL_TOKENIZER
#     _ensure_var_is_not_initialized(_GLOBAL_TOKENIZER, 'tokenizer')
#     _GLOBAL_TOKENIZER = build_tokenizer(args)
#     return _GLOBAL_TOKENIZER


# def rebuild_tokenizer(args):
#     global _GLOBAL_TOKENIZER
#     _GLOBAL_TOKENIZER = None
#     return _build_tokenizer(args)


def _set_tensorboard_writer(args):
    """Set tensorboard writer."""
    global _GLOBAL_TENSORBOARD_WRITER
    _ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER,
                                   'tensorboard writer')

    if hasattr(args, 'tensorboard_dir') and \
       args.tensorboard_dir and args.rank == (args.world_size - 1):
        try:
            from torch.utils.tensorboard import SummaryWriter
            print('> setting tensorboard ...')
            _GLOBAL_TENSORBOARD_WRITER = SummaryWriter(
                log_dir=args.tensorboard_dir,
                max_queue=args.tensorboard_queue_size)
        except ModuleNotFoundError:
            print('WARNING: TensorBoard writing requested but is not '
                  'available (are you using PyTorch 1.1.0 or later?), '
                  'no TensorBoard logs will be written.', flush=True)


def _set_adlr_autoresume(args):
    """Initialize ADLR autoresume."""
    global _GLOBAL_ADLR_AUTORESUME
    _ensure_var_is_not_initialized(_GLOBAL_ADLR_AUTORESUME, 'adlr autoresume')

    if args.adlr_autoresume:
        if args.rank == 0:
            print('enabling autoresume ...', flush=True)
        sys.path.append(os.environ.get('SUBMIT_SCRIPTS', '.'))
        try:
            from userlib.auto_resume import AutoResume
        except BaseException:
            print('ADLR autoresume is not available, exiting ...')
            sys.exit()

        _GLOBAL_ADLR_AUTORESUME = AutoResume


def _set_timers():
    """Initialize timers."""
    global _GLOBAL_TIMERS
    _ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers')
    _GLOBAL_TIMERS = Timers()


def _ensure_var_is_initialized(var, name):
    """Make sure the input variable is not None."""
    assert var is not None, '{} is not initialized.'.format(name)


def _ensure_var_is_not_initialized(var, name):
    """Make sure the input variable is not already initialized."""
    assert var is None, '{} is already initialized.'.format(name)


class _Timer:
    """Timer."""

    def __init__(self, name):
        self.name_ = name
        self.elapsed_ = 0.0
        self.started_ = False
        self.start_time = time.time()

    def start(self):
        """Start the timer."""
        assert not self.started_, 'timer has already been started'
        torch.cuda.synchronize()
        self.start_time = time.time()
        self.started_ = True

    def stop(self):
        """Stop the timer."""
        assert self.started_, 'timer is not started'
        torch.cuda.synchronize()
        self.elapsed_ += (time.time() - self.start_time)
        self.started_ = False

    def reset(self):
        """Reset timer."""
        self.elapsed_ = 0.0
        self.started_ = False

    def elapsed(self, reset=True):
        """Calculate the elapsed time."""
        started_ = self.started_
        # If the timing is in progress, end it first.
        if self.started_:
            self.stop()
        # Get the elapsed time.
        elapsed_ = self.elapsed_
        # Reset the elapsed time
        if reset:
            self.reset()
        # If timing was in progress, set it back.
        if started_:
            self.start()
        return elapsed_


class Timers:
    """Group of timers."""

    def __init__(self):
        self.timers = {}

    def __call__(self, name):
        if name not in self.timers:
            self.timers[name] = _Timer(name)
        return self.timers[name]

    def write(self, names, writer, iteration, normalizer=1.0, reset=False):
        """Write timers to a tensorboard writer."""
        # Currently when using add_scalars,
        # torch.utils.add_scalars makes each timer its own run, which
        # pollutes the runs list, so we just add each as a scalar.
        assert normalizer > 0.0
        for name in names:
            value = self.timers[name].elapsed(reset=reset) / normalizer
            writer.add_scalar(name + '-time', value, iteration)

    def log(self, names, normalizer=1.0, reset=True):
        """Log a group of timers."""
        assert normalizer > 0.0
        string = 'time (ms)'
        for name in names:
            elapsed_time = self.timers[name].elapsed(
                reset=reset) * 1000.0 / normalizer
            string += ' | {}: {:.2f}'.format(name, elapsed_time)
        if torch.distributed.is_initialized():
            if torch.distributed.get_rank() == (
                    torch.distributed.get_world_size() - 1):
                print(string, flush=True)
        else:
            print(string, flush=True)
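A small usage sketch for the Timers group above (a CUDA device is assumed, since _Timer.start/stop call torch.cuda.synchronize(); the timer name and workload are illustrative):

    import torch
    from apex.transformer.tensor_parallel.tests.global_vars import Timers

    timers = Timers()
    timers('matmul').start()
    a = torch.randn(1024, 1024, device='cuda')
    b = torch.randn(1024, 1024, device='cuda')
    c = a @ b
    timers('matmul').stop()
    # Prints directly when torch.distributed is not initialized,
    # e.g. "time (ms) | matmul: 0.42".
    timers.log(['matmul'])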
apex/transformer/tensor_parallel/utils.py (new file, mode 100644)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch


def ensure_divisibility(numerator, denominator):
    """Ensure that numerator is divisible by the denominator."""
    assert numerator % denominator == 0, "{} is not divisible by {}".format(
        numerator, denominator)


def divide(numerator, denominator):
    """Ensure that numerator is divisible by the denominator and return
    the division value."""
    ensure_divisibility(numerator, denominator)
    return numerator // denominator


def split_tensor_along_last_dim(tensor, num_partitions,
                                contiguous_split_chunks=False):
    """Split a tensor along its last dimension.
    Arguments:
        tensor: input tensor.
        num_partitions: number of partitions to split the tensor
        contiguous_split_chunks: If True, make each chunk contiguous
                                 in memory.
    """
    # Get the size and dimension.
    last_dim = tensor.dim() - 1
    last_dim_size = divide(tensor.size()[last_dim], num_partitions)
    # Split.
    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
    # Note: torch.split does not create contiguous tensors by default.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in tensor_list)

    return tensor_list


class VocabUtility:
    """Split the vocabulary into `world_size` chunks and return the
    first and last index of the vocabulary belonging to the `rank`
    partition: Note that indices are in [first, last)."""

    @staticmethod
    def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size,
                                                  rank, world_size):
        index_f = rank * per_partition_vocab_size
        index_l = index_f + per_partition_vocab_size
        return index_f, index_l

    @staticmethod
    def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
        per_partition_vocab_size = divide(global_vocab_size, world_size)
        return VocabUtility.vocab_range_from_per_partition_vocab_size(
            per_partition_vocab_size, rank, world_size)
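A toy-sized sketch of the two helpers above (shapes and the vocab size are illustrative):

    import torch
    from apex.transformer.tensor_parallel.utils import (
        VocabUtility, split_tensor_along_last_dim)

    # Split a [2, 12] tensor into 3 contiguous chunks along the last dimension.
    t = torch.arange(24).reshape(2, 12)
    chunks = split_tensor_along_last_dim(t, 3, contiguous_split_chunks=True)
    print([c.shape for c in chunks])   # three tensors of shape [2, 4]

    # Vocabulary shard owned by rank 1 of 4 for a padded vocab of 1024 tokens.
    first, last = VocabUtility.vocab_range_from_global_vocab_size(
        1024, rank=1, world_size=4)
    print(first, last)                 # 256 512, i.e. the half-open range [256, 512)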
csrc/amp_C_frontend.cpp (modified)
@@ -33,6 +33,13 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
   std::vector<std::vector<at::Tensor>> tensor_lists,
   at::optional<bool> per_tensor_python);

+std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_scale_cuda(
+  int chunk_size,
+  at::Tensor noop_flag,
+  std::vector<std::vector<at::Tensor>> tensor_lists,
+  float scale,
+  at::optional<bool> per_tensor_python);
+
 void multi_tensor_lamb_stage1_cuda(
   int chunk_size,
   at::Tensor noop_flag,
...
@@ -121,6 +128,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "out = a*x + b*y for a list of contiguous tensors");
   m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
         "Computes L2 norm for a list of contiguous tensors");
+  m.def("multi_tensor_l2norm_scale", &multi_tensor_l2norm_scale_cuda,
+        "Computes L2 norm for a list of contiguous tensors and does scaling");
   m.def("multi_tensor_lamb_stage1_cuda", &multi_tensor_lamb_stage1_cuda,
         "Computes update part of LAMB optimizer");
   m.def("multi_tensor_lamb_stage2_cuda", &multi_tensor_lamb_stage2_cuda,
...
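As a rough orientation only: a pure-PyTorch sketch of the quantity the multi_tensor_l2norm binding reports (a global L2 norm plus, optionally, per-tensor norms). The new multi_tensor_l2norm_scale binding additionally takes a scale factor per its docstring in the hunk above; the kernel's exact ordering of scaling vs. norm computation is not visible here, so treat this as an approximation, not the kernel's definition.

    import torch

    def l2norm_reference(tensors, per_tensor=False):
        # Per-tensor L2 norms, then the L2 norm of those norms (the global L2 norm).
        per_tensor_norms = torch.stack([t.float().norm(2) for t in tensors])
        total_norm = per_tensor_norms.norm(2)
        return total_norm, (per_tensor_norms if per_tensor else None)

    total, per_t = l2norm_reference([torch.randn(10), torch.randn(5)], per_tensor=True)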
csrc/fused_dense.cpp (new file, mode 100644)
#include <torch/extension.h>
#include <torch/torch.h>
#include <vector>
#include <stdio.h>
template
<
typename
T
>
int
linear_bias_forward_cuda
(
at
::
Tensor
input
,
T
*
weight
,
at
::
Tensor
bias
,
int
in_features
,
int
batch_size
,
int
out_features
,
at
::
Tensor
output
,
void
*
lt_workspace
);
template
<
typename
T
>
int
linear_bias_backward_cuda
(
T
*
input
,
T
*
weight
,
T
*
d_output
,
int
in_features
,
int
batch_size
,
int
out_features
,
T
*
d_weight
,
T
*
d_bias
,
T
*
d_input
,
void
*
lt_workspace
);
template
<
typename
T
>
int
linear_gelu_linear_forward_cuda
(
T
*
input
,
T
*
weight1
,
T
*
bias1
,
T
*
weight2
,
T
*
bias2
,
int
in_features
,
int
hidden_features
,
int
batch_size
,
int
out_features
,
T
*
output1
,
T
*
output2
,
T
*
gelu_in
,
void
*
lt_workspace
)
;
template
<
typename
T
>
int
linear_gelu_linear_backward_cuda
(
T
*
input
,
T
*
gelu_in
,
T
*
output1
,
T
*
weight1
,
T
*
weight2
,
T
*
d_output1
,
T
*
d_output2
,
int
in_features
,
int
batch_size
,
int
hidden_features
,
int
out_features
,
T
*
d_weight1
,
T
*
d_weight2
,
T
*
d_bias1
,
T
*
d_bias2
,
T
*
d_input
,
void
*
lt_workspace
);
at
::
Tensor
linear_bias_forward
(
at
::
Tensor
input
,
at
::
Tensor
weight
,
at
::
Tensor
bias
)
{
auto
batch_size
=
input
.
size
(
0
);
auto
in_features
=
input
.
size
(
1
);
int
out_features
=
weight
.
size
(
0
);
//auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
// create output/workspace tensor
auto
out
=
at
::
empty
({
batch_size
,
out_features
},
input
.
type
());
//auto reserved_space = at::empty({reserved_size}, inputs[0].type());
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
auto
lt_workspace
=
at
::
empty
({
1
<<
22
},
input
.
type
());
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
input
.
type
(),
"linear_bias_forward"
,
[
&
]
{
scalar_t
*
w_ptr
=
weight
.
data_ptr
<
scalar_t
>
();
scalar_t
*
b_ptr
=
bias
.
data_ptr
<
scalar_t
>
();
auto
result
=
linear_bias_forward_cuda
<
scalar_t
>
(
input
,
w_ptr
,
bias
,
in_features
,
batch_size
,
out_features
,
out
,
//out.data_ptr<scalar_t>(),
// reserved_space.data_ptr<scalar_t>(),
(
void
*
)
(
lt_workspace
.
data_ptr
<
scalar_t
>
()));
});
return
{
out
};
}
std
::
vector
<
at
::
Tensor
>
linear_bias_backward
(
at
::
Tensor
input
,
at
::
Tensor
weight
,
at
::
Tensor
d_output
)
{
auto
batch_size
=
input
.
size
(
0
);
auto
in_features
=
input
.
size
(
1
);
int
out_features
=
weight
.
size
(
0
);
//auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
// create output/workspace tensor
auto
d_weight
=
at
::
empty
({
out_features
,
in_features
},
input
.
type
());
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600
auto
d_bias
=
d_output
.
view
({
-
1
,
out_features
}).
sum
(
0
,
false
);
#else
auto
d_bias
=
at
::
empty
({
out_features
},
input
.
type
());
#endif
auto
d_input
=
at
::
empty
({
batch_size
,
in_features
},
input
.
type
());
//auto reserved_space = at::empty({reserved_size}, inputs[0].type());
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
auto
lt_workspace
=
at
::
empty
({
1
<<
22
},
input
.
type
());
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
input
.
type
(),
"linear_bias_backward"
,
[
&
]
{
scalar_t
*
w_ptr
=
weight
.
data_ptr
<
scalar_t
>
();
scalar_t
*
d_b_ptr
=
d_bias
.
data_ptr
<
scalar_t
>
();
auto
result
=
linear_bias_backward_cuda
<
scalar_t
>
(
input
.
data_ptr
<
scalar_t
>
(),
w_ptr
,
d_output
.
data_ptr
<
scalar_t
>
(),
in_features
,
batch_size
,
out_features
,
d_weight
.
data_ptr
<
scalar_t
>
(),
d_bias
.
data_ptr
<
scalar_t
>
(),
d_input
.
data_ptr
<
scalar_t
>
(),
// reserved_space.data_ptr<scalar_t>(),
(
void
*
)
(
lt_workspace
.
data_ptr
<
scalar_t
>
()));
});
return
{
d_input
,
d_weight
,
d_bias
};
}
std
::
vector
<
at
::
Tensor
>
linear_gelu_linear_forward
(
at
::
Tensor
input
,
at
::
Tensor
weight1
,
at
::
Tensor
bias1
,
at
::
Tensor
weight2
,
at
::
Tensor
bias2
)
{
auto
batch_size
=
input
.
size
(
0
);
auto
in_features
=
input
.
size
(
1
);
int
hidden_features
=
weight1
.
size
(
0
);
int
out_features
=
weight2
.
size
(
0
);
//auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
// create output/workspace tensor
auto
output1
=
at
::
empty
({
batch_size
,
hidden_features
},
input
.
type
());
auto
gelu_in
=
at
::
empty
({
batch_size
,
hidden_features
},
input
.
type
());
auto
output2
=
at
::
empty
({
batch_size
,
out_features
},
input
.
type
());
//auto reserved_space = at::empty({reserved_size}, inputs[0].type());
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
auto
lt_workspace
=
at
::
empty
({
1
<<
22
},
input
.
type
());
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
input
.
type
(),
"linear_gelu_linear_forward"
,
[
&
]
{
scalar_t
*
w1_ptr
=
weight1
.
data_ptr
<
scalar_t
>
();
scalar_t
*
b1_ptr
=
bias1
.
data_ptr
<
scalar_t
>
();
scalar_t
*
w2_ptr
=
weight2
.
data_ptr
<
scalar_t
>
();
scalar_t
*
b2_ptr
=
bias2
.
data_ptr
<
scalar_t
>
();
auto
result
=
linear_gelu_linear_forward_cuda
<
scalar_t
>
(
input
.
data_ptr
<
scalar_t
>
(),
w1_ptr
,
b1_ptr
,
w2_ptr
,
b2_ptr
,
in_features
,
hidden_features
,
batch_size
,
out_features
,
output1
.
data_ptr
<
scalar_t
>
(),
output2
.
data_ptr
<
scalar_t
>
(),
gelu_in
.
data_ptr
<
scalar_t
>
(),
// reserved_space.data_ptr<scalar_t>(),
(
void
*
)
(
lt_workspace
.
data_ptr
<
scalar_t
>
()));
});
return
{
output1
,
output2
,
gelu_in
};
}
std::vector<at::Tensor> linear_gelu_linear_backward(at::Tensor input, at::Tensor gelu_in, at::Tensor output1, at::Tensor weight1, at::Tensor weight2, at::Tensor d_output2) {
  auto batch_size = input.size(0);
  auto in_features = input.size(1);
  int hidden_features = weight1.size(0);
  int out_features = weight2.size(0);
  //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
  // create output/workspace tensor
  auto d_weight1 = at::empty({hidden_features, in_features}, input.type());
  auto d_weight2 = at::empty({out_features, hidden_features}, input.type());
  auto d_bias1 = at::empty({hidden_features}, input.type());
  auto d_bias2 = at::empty({out_features}, input.type());
  auto d_input = at::empty({batch_size, in_features}, input.type());
  auto d_output1 = at::empty({batch_size, hidden_features}, input.type());
  //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
  // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
  auto lt_workspace = at::empty({1 << 22}, input.type());
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "linear_bias_backward", [&] {
    //scalar_t* w_ptr = weight.data_ptr<scalar_t>();
    //scalar_t* d_b_ptr = d_bias.data_ptr<scalar_t>();
    auto result = linear_gelu_linear_backward_cuda<scalar_t>(
        input.data_ptr<scalar_t>(),
        gelu_in.data_ptr<scalar_t>(),
        output1.data_ptr<scalar_t>(),
        weight1.data_ptr<scalar_t>(),
        weight2.data_ptr<scalar_t>(),
        d_output1.data_ptr<scalar_t>(),
        d_output2.data_ptr<scalar_t>(),
        in_features,
        batch_size,
        hidden_features,
        out_features,
        d_weight1.data_ptr<scalar_t>(),
        d_weight2.data_ptr<scalar_t>(),
        d_bias1.data_ptr<scalar_t>(),
        d_bias2.data_ptr<scalar_t>(),
        d_input.data_ptr<scalar_t>(),
        // reserved_space.data_ptr<scalar_t>(),
        (void*) (lt_workspace.data_ptr<scalar_t>()));
  });
  return {d_input, d_weight1, d_bias1, d_weight2, d_bias2};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("linear_bias_forward", &linear_bias_forward, "linear bias forward");
  m.def("linear_bias_backward", &linear_bias_backward, "linear bias backward");
  m.def("linear_gelu_linear_forward", &linear_gelu_linear_forward, "linear gelu linear forward");
  m.def("linear_gelu_linear_backward", &linear_gelu_linear_backward, "linear gelu linear backward");
}
csrc/fused_dense_cuda.cu
0 → 100644
View file @
f79993d9
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <torch/torch.h>
/* Includes, cuda */
#include <cublas_v2.h>
#include <cuda_runtime.h>
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
// includes cublaslt
#include <cublasLt.h>
#endif
// FP64 Wrapper around cublas GEMMEx
cublasStatus_t gemm_bias(
    cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k,
    const float* alpha, double* A, int lda, double* B, int ldb,
    const float* beta, double* C, int ldc) {
  return cublasGemmEx(
      handle, transa, transb, m, n, k, alpha,
      A, CUDA_R_64F, lda, B, CUDA_R_64F, ldb, beta,
      C, CUDA_R_64F, ldc, CUDA_R_64F, CUBLAS_GEMM_DEFAULT);
}
// FP32 Wrapper around cublas GEMMEx
cublasStatus_t gemm_bias(
    cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k,
    const float* alpha, float* A, int lda, float* B, int ldb,
    const float* beta, float* C, int ldc) {
  return cublasGemmEx(
      handle, transa, transb, m, n, k, alpha,
      A, CUDA_R_32F, lda, B, CUDA_R_32F, ldb, beta,
      C, CUDA_R_32F, ldc, CUDA_R_32F, CUBLAS_GEMM_DEFAULT);
}
// FP16 Tensor core wrapper around cublas GEMMEx
cublasStatus_t gemm_bias(
    cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k,
    const float* alpha, at::Half* A, int lda, at::Half* B, int ldb,
    const float* beta, at::Half* C, int ldc) {
  return cublasGemmEx(
      handle, transa, transb, m, n, k, alpha,
      A, CUDA_R_16F, lda, B, CUDA_R_16F, ldb, beta,
      C, CUDA_R_16F, ldc, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
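These GemmEx wrappers are always handed PyTorch's row-major buffers, which cuBLAS reads as column-major matrices. A row-major activation X of shape [batch, in] is therefore seen as the column-major matrix Xᵀ (in × batch, leading dimension in), and a row-major weight W of shape [out, in] is seen as Wᵀ (in × out). The forward call used later, gemm_bias(handle, CUBLAS_OP_T, CUBLAS_OP_N, out, batch, in, ...), therefore computes, in column-major terms,

\[
C_{\mathrm{col}} \;=\; \operatorname{op}(A)\,\operatorname{op}(B) \;=\; (W^{\top})^{\top} X^{\top} \;=\; W X^{\top} \;=\; (X W^{\top})^{\top},
\]

so reading C back as a row-major [batch, out] tensor yields Y = X Wᵀ (plus the bias, when the cuBLASLt epilogue below is used).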
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
int gemm_bias_lt(
    cublasLtHandle_t ltHandle,
    cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k,
    const float* alpha, /* host pointer */
    at::Half* A, int lda, at::Half* B, int ldb,
    const float* beta, /* host pointer */
    at::Half* C, int ldc,
    void* workspace, size_t workspaceSize,
    cudaStream_t stream, bool use_bias, const void* bias) {
  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
  cublasLtMatmulDescOpaque_t operationDesc = {};
  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  cublasLtMatmulPreferenceOpaque_t preference = {};
  int returnedResults = 0;
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;

  // Create operation descriptor; see cublasLtMatmulDescAttributes_t for details
  // about defaults; here we just set the transforms for A and B.
  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  if (use_bias) {
    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
    if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
    epilogue = CUBLASLT_EPILOGUE_BIAS;
  }
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // Create matrix descriptors. Not setting any extra attributes.
  status = cublasLtMatrixLayoutInit(&Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(&Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // Create preference handle; in general, extra attributes can be used here to
  // disable tensor ops or to make sure the selected algo will work with badly
  // aligned A, B, C. However, for simplicity here we assume A, B, C are always
  // well aligned (e.g., directly come from cudaMalloc).
  status = cublasLtMatmulPreferenceInit(&preference);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulPreferenceSetAttribute(&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // We just need the best available heuristic to try and run matmul. There is
  // no guarantee that this will work. For example, if A is badly aligned, you
  // can request more (e.g. 32) algos and try to run them one by one until
  // something works.
  status = cublasLtMatmulAlgoGetHeuristic(ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  if (returnedResults == 0) {
    status = CUBLAS_STATUS_NOT_SUPPORTED;
    goto CLEANUP;
  }
  status = cublasLtMatmul(ltHandle, &operationDesc, alpha,
                          A, &Adesc, B, &Bdesc, beta,
                          C, &Cdesc, C, &Cdesc,
                          //&heuristicResult.algo,
                          NULL,
                          workspace, workspaceSize, stream);
CLEANUP:
  // Descriptors are no longer needed as all GPU work was already enqueued.
  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}
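What the cuBLASLt path adds over the plain GemmEx wrappers is the fused epilogue: with use_bias set, CUBLASLT_EPILOGUE_BIAS folds the bias broadcast into the same matmul kernel. In row-major terms the result matches this ATen sketch; the helper below is illustrative only and not part of the extension.

#include <ATen/ATen.h>

// Hypothetical reference for gemm_bias_lt with use_bias == true.
at::Tensor gemm_bias_reference(at::Tensor input, at::Tensor weight, at::Tensor bias) {
  // input [B, in], weight [out, in], bias [out] -> output [B, out]
  return input.matmul(weight.t()) + bias;
}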
int gemm_bias_lt(
    cublasLtHandle_t ltHandle,
    cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k,
    const float* alpha, /* host pointer */
    double* A, int lda, double* B, int ldb,
    const float* beta, /* host pointer */
    double* C, int ldc,
    void* workspace, size_t workspaceSize,
    cudaStream_t stream, bool use_bias, const void* bias) {
  return 1;
}
int gemm_bias_lt(
    cublasLtHandle_t ltHandle,
    cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k,
    const float* alpha, /* host pointer */
    float* A, int lda, float* B, int ldb,
    const float* beta, /* host pointer */
    float* C, int ldc,
    void* workspace, size_t workspaceSize,
    cudaStream_t stream, bool use_bias, const void* bias) {
  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
  cublasLtMatmulDescOpaque_t operationDesc = {};
  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  cublasLtMatmulPreferenceOpaque_t preference = {};
  int returnedResults = 0;
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;

  // Create operation descriptor; see cublasLtMatmulDescAttributes_t for details
  // about defaults; here we just set the transforms for A and B.
  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  if (use_bias) {
    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
    if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
    epilogue = CUBLASLT_EPILOGUE_BIAS;
  }
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // Create matrix descriptors. Not setting any extra attributes.
  status = cublasLtMatrixLayoutInit(&Adesc, CUDA_R_32F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(&Bdesc, CUDA_R_32F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_32F, m, n, ldc);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // Create preference handle; in general, extra attributes can be used here to
  // disable tensor ops or to make sure the selected algo will work with badly
  // aligned A, B, C. However, for simplicity here we assume A, B, C are always
  // well aligned (e.g., directly come from cudaMalloc).
  status = cublasLtMatmulPreferenceInit(&preference);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulPreferenceSetAttribute(&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // We just need the best available heuristic to try and run matmul. There is
  // no guarantee that this will work. For example, if A is badly aligned, you
  // can request more (e.g. 32) algos and try to run them one by one until
  // something works.
  status = cublasLtMatmulAlgoGetHeuristic(ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  if (returnedResults == 0) {
    status = CUBLAS_STATUS_NOT_SUPPORTED;
    goto CLEANUP;
  }
  status = cublasLtMatmul(ltHandle, &operationDesc, alpha,
                          A, &Adesc, B, &Bdesc, beta,
                          C, &Cdesc, C, &Cdesc,
                          &heuristicResult.algo,
                          workspace, workspaceSize, stream);
CLEANUP:
  // Descriptors are no longer needed as all GPU work was already enqueued.
  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}
int gemm_bias_gelu_lt(
    cublasLtHandle_t ltHandle,
    cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k,
    const float* alpha, /* host pointer */
    at::Half* A, int lda, at::Half* B, int ldb,
    const float* beta, /* host pointer */
    at::Half* C, int64_t ldc,
    void* workspace, size_t workspaceSize,
    cudaStream_t stream, bool use_bias,
    const void* gelu_in, const void* bias) {
  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
  cublasLtMatmulDescOpaque_t operationDesc = {};
  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  cublasLtMatmulPreferenceOpaque_t preference = {};
  int returnedResults = 0;
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_GELU_AUX;

  // Create operation descriptor; see cublasLtMatmulDescAttributes_t for details
  // about defaults; here we just set the transforms for A and B.
  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &gelu_in, sizeof(gelu_in));
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc));
  if (use_bias) {
    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
    if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
    epilogue = CUBLASLT_EPILOGUE_GELU_AUX_BIAS;
  }
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // Create matrix descriptors. Not setting any extra attributes.
  status = cublasLtMatrixLayoutInit(&Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(&Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // Create preference handle; in general, extra attributes can be used here to
  // disable tensor ops or to make sure the selected algo will work with badly
  // aligned A, B, C. However, for simplicity here we assume A, B, C are always
  // well aligned (e.g., directly come from cudaMalloc).
  status = cublasLtMatmulPreferenceInit(&preference);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulPreferenceSetAttribute(&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // We just need the best available heuristic to try and run matmul. There is
  // no guarantee that this will work. For example, if A is badly aligned, you
  // can request more (e.g. 32) algos and try to run them one by one until
  // something works.
  status = cublasLtMatmulAlgoGetHeuristic(ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  if (returnedResults == 0) {
    status = CUBLAS_STATUS_NOT_SUPPORTED;
    goto CLEANUP;
  }
  status = cublasLtMatmul(ltHandle, &operationDesc, alpha,
                          A, &Adesc, B, &Bdesc, beta,
                          C, &Cdesc, C, &Cdesc,
                          //&heuristicResult.algo,
                          NULL,
                          workspace, workspaceSize, stream);
CLEANUP:
  // Descriptors are no longer needed as all GPU work was already enqueued.
  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}
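With the GELU_AUX epilogues the matmul writes two things: the activated value into C and the pre-activation (the GELU input) into the aux buffer, which the backward pass reuses. A rough ATen equivalent is sketched below, using ATen's default erf-based GELU for illustration; which GELU variant the cuBLASLt epilogue applies internally is a library detail not visible from this file, and the helper name is hypothetical.

#include <ATen/ATen.h>
#include <utility>

// Hypothetical reference for the fused first layer of the MLP.
std::pair<at::Tensor, at::Tensor> gemm_bias_gelu_reference(
    at::Tensor input, at::Tensor weight1, at::Tensor bias1) {
  auto gelu_in = input.matmul(weight1.t()) + bias1;  // stashed via the AUX pointer
  auto output1 = at::gelu(gelu_in);                  // written to C
  return {gelu_in, output1};
}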
int gemm_bias_gelu_lt(
    cublasLtHandle_t ltHandle,
    cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k,
    const float* alpha, /* host pointer */
    double* A, int lda, double* B, int ldb,
    const float* beta, /* host pointer */
    double* C, int ldc,
    void* workspace, size_t workspaceSize,
    cudaStream_t stream, bool use_bias,
    const void* gelu_in, const void* bias) {
  return 1;
}
int gemm_bias_gelu_lt(
    cublasLtHandle_t ltHandle,
    cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k,
    const float* alpha, /* host pointer */
    float* A, int lda, float* B, int ldb,
    const float* beta, /* host pointer */
    float* C, int64_t ldc,
    void* workspace, size_t workspaceSize,
    cudaStream_t stream, bool use_bias,
    const void* gelu_in, const void* bias) {
  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
  cublasLtMatmulDescOpaque_t operationDesc = {};
  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  cublasLtMatmulPreferenceOpaque_t preference = {};
  int returnedResults = 0;
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_GELU_AUX;

  // Create operation descriptor; see cublasLtMatmulDescAttributes_t for details
  // about defaults; here we just set the transforms for A and B.
  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &gelu_in, sizeof(gelu_in));
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc));
  if (use_bias) {
    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
    if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
    epilogue = CUBLASLT_EPILOGUE_GELU_AUX_BIAS;
  }
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // Create matrix descriptors. Not setting any extra attributes.
  status = cublasLtMatrixLayoutInit(&Adesc, CUDA_R_32F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(&Bdesc, CUDA_R_32F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_32F, m, n, ldc);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // Create preference handle; in general, extra attributes can be used here to
  // disable tensor ops or to make sure the selected algo will work with badly
  // aligned A, B, C. However, for simplicity here we assume A, B, C are always
  // well aligned (e.g., directly come from cudaMalloc).
  status = cublasLtMatmulPreferenceInit(&preference);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulPreferenceSetAttribute(&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // We just need the best available heuristic to try and run matmul. There is
  // no guarantee that this will work. For example, if A is badly aligned, you
  // can request more (e.g. 32) algos and try to run them one by one until
  // something works.
  status = cublasLtMatmulAlgoGetHeuristic(ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  if (returnedResults == 0) {
    status = CUBLAS_STATUS_NOT_SUPPORTED;
    goto CLEANUP;
  }
  status = cublasLtMatmul(ltHandle, &operationDesc, alpha,
                          A, &Adesc, B, &Bdesc, beta,
                          C, &Cdesc, C, &Cdesc,
                          //&heuristicResult.algo,
                          NULL,
                          workspace, workspaceSize, stream);
CLEANUP:
  // Descriptors are no longer needed as all GPU work was already enqueued.
  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}
int gemm_bgradb_lt(
    cublasLtHandle_t ltHandle,
    cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k,
    const float* alpha, /* host pointer */
    at::Half* A, int lda, at::Half* B, int ldb,
    const float* beta, /* host pointer */
    at::Half* C, int ldc,
    void* workspace, size_t workspaceSize,
    cudaStream_t stream, bool use_bias, const void* bgrad) {
  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
  cublasLtMatmulDescOpaque_t operationDesc = {};
  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  cublasLtMatmulPreferenceOpaque_t preference = {};
  int returnedResults = 0;
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;

  // Create operation descriptor; see cublasLtMatmulDescAttributes_t for details
  // about defaults; here we just set the transforms for A and B.
  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  if (use_bias) {
    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));
    if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
    epilogue = CUBLASLT_EPILOGUE_BGRADB;
  }
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // Create matrix descriptors. Not setting any extra attributes.
  status = cublasLtMatrixLayoutInit(&Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(&Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // Create preference handle; in general, extra attributes can be used here to
  // disable tensor ops or to make sure the selected algo will work with badly
  // aligned A, B, C. However, for simplicity here we assume A, B, C are always
  // well aligned (e.g., directly come from cudaMalloc).
  status = cublasLtMatmulPreferenceInit(&preference);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulPreferenceSetAttribute(&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // We just need the best available heuristic to try and run matmul. There is
  // no guarantee that this will work. For example, if A is badly aligned, you
  // can request more (e.g. 32) algos and try to run them one by one until
  // something works.
  status = cublasLtMatmulAlgoGetHeuristic(ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  if (returnedResults == 0) {
    status = CUBLAS_STATUS_NOT_SUPPORTED;
    goto CLEANUP;
  }
  status = cublasLtMatmul(ltHandle, &operationDesc, alpha,
                          A, &Adesc, B, &Bdesc, beta,
                          C, &Cdesc, C, &Cdesc,
                          //&heuristicResult.algo,
                          NULL,
                          workspace, workspaceSize, stream);
CLEANUP:
  // Descriptors are no longer needed as all GPU work was already enqueued.
  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}
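CUBLASLT_EPILOGUE_BGRADB fuses the bias-gradient reduction into the weight-gradient GEMM: the matrix product goes to C while a sum over the B operand is written through the bgrad pointer. As this wrapper is used by linear_bias_backward_cuda further down, that corresponds to the following ATen sketch (hypothetical helper, for illustration only).

#include <ATen/ATen.h>
#include <utility>

// Hypothetical reference for the fused wgrad + bias-grad call.
std::pair<at::Tensor, at::Tensor> wgrad_bgrad_reference(at::Tensor input, at::Tensor d_output) {
  auto d_weight = d_output.t().matmul(input);  // the GEMM result written to C
  auto d_bias   = d_output.sum(0);             // the BGRADB epilogue reduction
  return {d_weight, d_bias};
}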
int gemm_bgradb_lt(
    cublasLtHandle_t ltHandle,
    cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k,
    const float* alpha, /* host pointer */
    double* A, int lda, double* B, int ldb,
    const float* beta, /* host pointer */
    double* C, int ldc,
    void* workspace, size_t workspaceSize,
    cudaStream_t stream, bool use_bias, const void* bgrad) {
  return 1;
}
int gemm_bgradb_lt(
    cublasLtHandle_t ltHandle,
    cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k,
    const float* alpha, /* host pointer */
    float* A, int lda, float* B, int ldb,
    const float* beta, /* host pointer */
    float* C, int ldc,
    void* workspace, size_t workspaceSize,
    cudaStream_t stream, bool use_bias, const void* bgrad) {
  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
  cublasLtMatmulDescOpaque_t operationDesc = {};
  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  cublasLtMatmulPreferenceOpaque_t preference = {};
  int returnedResults = 0;
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;

  // Create operation descriptor; see cublasLtMatmulDescAttributes_t for details
  // about defaults; here we just set the transforms for A and B.
  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  if (use_bias) {
    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));
    if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
    epilogue = CUBLASLT_EPILOGUE_BGRADB;
  }
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // Create matrix descriptors. Not setting any extra attributes.
  status = cublasLtMatrixLayoutInit(&Adesc, CUDA_R_32F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(&Bdesc, CUDA_R_32F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_32F, m, n, ldc);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // Create preference handle; in general, extra attributes can be used here to
  // disable tensor ops or to make sure the selected algo will work with badly
  // aligned A, B, C. However, for simplicity here we assume A, B, C are always
  // well aligned (e.g., directly come from cudaMalloc).
  status = cublasLtMatmulPreferenceInit(&preference);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulPreferenceSetAttribute(&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // We just need the best available heuristic to try and run matmul. There is
  // no guarantee that this will work. For example, if A is badly aligned, you
  // can request more (e.g. 32) algos and try to run them one by one until
  // something works.
  status = cublasLtMatmulAlgoGetHeuristic(ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  if (returnedResults == 0) {
    status = CUBLAS_STATUS_NOT_SUPPORTED;
    goto CLEANUP;
  }
  status = cublasLtMatmul(ltHandle, &operationDesc, alpha,
                          A, &Adesc, B, &Bdesc, beta,
                          C, &Cdesc, C, &Cdesc,
                          &heuristicResult.algo,
                          workspace, workspaceSize, stream);
CLEANUP:
  // Descriptors are no longer needed as all GPU work was already enqueued.
  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}
int gemm_dgelu_bgradb_lt(
    cublasLtHandle_t ltHandle,
    cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k,
    const float* alpha, /* host pointer */
    at::Half* A, int lda, at::Half* B, int ldb,
    const float* beta, /* host pointer */
    at::Half* C, int64_t ldc,
    void* workspace, size_t workspaceSize,
    cudaStream_t stream,
    const void* gelu_in, const void* bgrad) {
  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
  cublasLtMatmulDescOpaque_t operationDesc = {};
  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  cublasLtMatmulPreferenceOpaque_t preference = {};
  int returnedResults = 0;
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DGELU_BGRAD;

  // Create operation descriptor; see cublasLtMatmulDescAttributes_t for details
  // about defaults; here we just set the transforms for A and B.
  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &gelu_in, sizeof(gelu_in));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc));
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // Create matrix descriptors. Not setting any extra attributes.
  status = cublasLtMatrixLayoutInit(&Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(&Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // Create preference handle; in general, extra attributes can be used here to
  // disable tensor ops or to make sure the selected algo will work with badly
  // aligned A, B, C. However, for simplicity here we assume A, B, C are always
  // well aligned (e.g., directly come from cudaMalloc).
  status = cublasLtMatmulPreferenceInit(&preference);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulPreferenceSetAttribute(&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // We just need the best available heuristic to try and run matmul. There is
  // no guarantee that this will work. For example, if A is badly aligned, you
  // can request more (e.g. 32) algos and try to run them one by one until
  // something works.
  status = cublasLtMatmulAlgoGetHeuristic(ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  if (returnedResults == 0) {
    status = CUBLAS_STATUS_NOT_SUPPORTED;
    goto CLEANUP;
  }
  status = cublasLtMatmul(ltHandle, &operationDesc, alpha,
                          A, &Adesc, B, &Bdesc, beta,
                          C, &Cdesc, C, &Cdesc,
                          //&heuristicResult.algo,
                          NULL,
                          workspace, workspaceSize, stream);
CLEANUP:
  // Descriptors are no longer needed as all GPU work was already enqueued.
  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}
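CUBLASLT_EPILOGUE_DGELU_BGRAD chains three steps into one call: the data-gradient GEMM, the GELU derivative evaluated at the stashed pre-activation, and the reduction of the result into a bias gradient. Assuming ATen's gelu_backward has its usual (grad, input) signature, an illustrative equivalent is sketched below; the helper name is hypothetical.

#include <ATen/ATen.h>
#include <utility>

// Hypothetical reference for the fused dgrad-through-GELU + bias-grad call.
std::pair<at::Tensor, at::Tensor> dgelu_bgrad_reference(
    at::Tensor d_output2, at::Tensor weight2, at::Tensor gelu_in) {
  // grad * dGELU(gelu_in)/dx, which is what the DGELU part of the epilogue applies
  auto d_hidden = at::gelu_backward(d_output2.matmul(weight2), gelu_in);
  auto d_bias1  = d_hidden.sum(0);  // the BGRAD part of the epilogue
  return {d_hidden, d_bias1};
}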
int gemm_dgelu_bgradb_lt(
    cublasLtHandle_t ltHandle,
    cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k,
    const float* alpha, /* host pointer */
    double* A, int lda, double* B, int ldb,
    const float* beta, /* host pointer */
    double* C, int ldc,
    void* workspace, size_t workspaceSize,
    cudaStream_t stream,
    const void* gelu_in, const void* bgrad) {
  return 1;
}
int gemm_dgelu_bgradb_lt(
    cublasLtHandle_t ltHandle,
    cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k,
    const float* alpha, /* host pointer */
    float* A, int lda, float* B, int ldb,
    const float* beta, /* host pointer */
    float* C, int64_t ldc,
    void* workspace, size_t workspaceSize,
    cudaStream_t stream,
    const void* gelu_in, const void* bgrad) {
  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
  cublasLtMatmulDescOpaque_t operationDesc = {};
  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
  cublasLtMatmulPreferenceOpaque_t preference = {};
  int returnedResults = 0;
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DGELU_BGRAD;

  // Create operation descriptor; see cublasLtMatmulDescAttributes_t for details
  // about defaults; here we just set the transforms for A and B.
  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &gelu_in, sizeof(gelu_in));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc));
  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // Create matrix descriptors. Not setting any extra attributes.
  status = cublasLtMatrixLayoutInit(&Adesc, CUDA_R_32F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(&Bdesc, CUDA_R_32F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_32F, m, n, ldc);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // Create preference handle; in general, extra attributes can be used here to
  // disable tensor ops or to make sure the selected algo will work with badly
  // aligned A, B, C. However, for simplicity here we assume A, B, C are always
  // well aligned (e.g., directly come from cudaMalloc).
  status = cublasLtMatmulPreferenceInit(&preference);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  status = cublasLtMatmulPreferenceSetAttribute(&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;

  // We just need the best available heuristic to try and run matmul. There is
  // no guarantee that this will work. For example, if A is badly aligned, you
  // can request more (e.g. 32) algos and try to run them one by one until
  // something works.
  status = cublasLtMatmulAlgoGetHeuristic(ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
  if (returnedResults == 0) {
    status = CUBLAS_STATUS_NOT_SUPPORTED;
    goto CLEANUP;
  }
  status = cublasLtMatmul(ltHandle, &operationDesc, alpha,
                          A, &Adesc, B, &Bdesc, beta,
                          C, &Cdesc, C, &Cdesc,
                          //&heuristicResult.algo,
                          NULL,
                          workspace, workspaceSize, stream);
CLEANUP:
  // Descriptors are no longer needed as all GPU work was already enqueued.
  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}
#endif
template <typename T>
int linear_bias_forward_cuda(at::Tensor input, T* weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void* lt_workspace) {
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  // Get the stream from cublas handle to reuse for biasReLU kernel.
  cudaStream_t stream;
  cublasGetStream(handle, &stream);
  const float alpha = 1.0;
  const float beta_zero = 0.0;
  const float beta_one = 1.0;
  int status = 1;
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
  status = gemm_bias_lt(
      (cublasLtHandle_t)handle,
      CUBLAS_OP_T,
      CUBLAS_OP_N,
      out_features,
      batch_size,
      in_features,
      &alpha, /* host pointer */
      weight,
      in_features,
      input.data_ptr<T>(),
      in_features,
      &beta_zero, /* host pointer */
      output.data_ptr<T>(),
      out_features,
      lt_workspace,
      1 << 22,
      stream,
      true,
      static_cast<const void*>(bias.data_ptr<T>()));
#endif
  if (status != 0) {
    output.copy_(bias);
    status = gemm_bias(
        handle, CUBLAS_OP_T, CUBLAS_OP_N, out_features, batch_size, in_features,
        &alpha, weight, in_features, input.data_ptr<T>(), in_features,
        &beta_one, output.data_ptr<T>(), out_features);
  }
  return status;
}
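Note the fallback structure: if the cuBLASLt path is unavailable (CUBLAS_VERSION below 11.6) or reports a failure, the bias is first broadcast into the output buffer and the plain GEMM then accumulates on top of it with beta = 1. A compact ATen sketch of that path (hypothetical helper, for illustration):

#include <ATen/ATen.h>

// Hypothetical reference for the non-LT fallback in linear_bias_forward_cuda.
at::Tensor linear_bias_forward_fallback_reference(
    at::Tensor input, at::Tensor weight, at::Tensor bias) {
  // output.copy_(bias): broadcast the bias into [batch, out_features]
  auto output = bias.expand({input.size(0), weight.size(0)}).contiguous();
  // gemm_bias with beta = 1: add X * W^T on top of the broadcast bias
  output.add_(input.matmul(weight.t()));
  return output;
}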
template <typename T>
int linear_bias_backward_cuda(T* input, T* weight, T* d_output, int in_features, int batch_size, int out_features, T* d_weight, T* d_bias, T* d_input, void* lt_workspace) {
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  // Get the stream from cublas handle to reuse for biasReLU kernel.
  cudaStream_t stream;
  cublasGetStream(handle, &stream);
  const float alpha = 1.0;
  const float beta_zero = 0.0;
  int status = 1;
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
  status = gemm_bgradb_lt(
      (cublasLtHandle_t)handle,
      CUBLAS_OP_N,
      CUBLAS_OP_T,
      in_features,
      out_features,
      batch_size,
      &alpha, /* host pointer */
      input,
      in_features,
      d_output,
      out_features,
      &beta_zero, /* host pointer */
      d_weight,
      in_features,
      lt_workspace,
      1 << 22,
      stream,
      true,
      static_cast<const void*>(d_bias));
#endif
  if (status != 0) {
    status = gemm_bias(
        handle, CUBLAS_OP_N, CUBLAS_OP_T, in_features, out_features, batch_size,
        &alpha, input, in_features, d_output, out_features,
        &beta_zero, d_weight, in_features);
  }
  status = gemm_bias(
      handle, CUBLAS_OP_N, CUBLAS_OP_N, in_features, batch_size, out_features,
      &alpha, weight, in_features, d_output, out_features,
      &beta_zero, d_input, in_features);
  return status;
}
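One thing worth noticing: the non-LT fallback above only reproduces the d_weight and d_input GEMMs; the bias gradient is produced solely by the BGRADB epilogue, so on toolkits without cuBLASLt 11.6 the d_bias buffer is left unwritten by this function. If that path mattered, a separate reduction along the lines of the hypothetical sketch below would still be needed:

#include <ATen/ATen.h>

// Hypothetical ATen fallback for the bias gradient (not in this file).
at::Tensor bias_grad_fallback(at::Tensor d_output) {
  return d_output.sum(0);
}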
template <typename T>
int linear_gelu_linear_forward_cuda(T* input, T* weight1, T* bias1, T* weight2, T* bias2, int in_features, int hidden_features, int batch_size, int out_features, T* output1, T* output2, T* gelu_in, void* lt_workspace) {
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  // Get the stream from cublas handle to reuse for biasReLU kernel.
  cudaStream_t stream;
  cublasGetStream(handle, &stream);
  const float alpha = 1.0;
  const float beta_zero = 0.0;
  int status = 1;
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
  status = gemm_bias_gelu_lt(
      (cublasLtHandle_t)handle,
      CUBLAS_OP_T,
      CUBLAS_OP_N,
      hidden_features,
      batch_size,
      in_features,
      &alpha, /* host pointer */
      weight1,
      in_features,
      input,
      in_features,
      &beta_zero, /* host pointer */
      output1,
      hidden_features,
      lt_workspace,
      1 << 22,
      stream,
      true,
      static_cast<const void*>(gelu_in),
      static_cast<const void*>(bias1));
  status = gemm_bias_lt(
      (cublasLtHandle_t)handle,
      CUBLAS_OP_T,
      CUBLAS_OP_N,
      out_features,
      batch_size,
      hidden_features,
      &alpha, /* host pointer */
      weight2,
      hidden_features,
      output1,
      hidden_features,
      &beta_zero, /* host pointer */
      output2,
      out_features,
      lt_workspace,
      1 << 22,
      stream,
      true,
      static_cast<const void*>(bias2));
  return status;
#else
  return 1;
#endif
}
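Taken together, the two fused calls implement a GELU MLP forward whose three outputs line up with the tensors returned by linear_gelu_linear_forward in fused_dense.cpp. A plain-ATen reference of the same computation (hypothetical helper, erf-based GELU assumed):

#include <ATen/ATen.h>
#include <vector>

// Hypothetical reference for the fused MLP forward.
std::vector<at::Tensor> mlp_forward_reference(
    at::Tensor input, at::Tensor weight1, at::Tensor bias1,
    at::Tensor weight2, at::Tensor bias2) {
  auto gelu_in = input.matmul(weight1.t()) + bias1;   // saved for the backward pass
  auto output1 = at::gelu(gelu_in);                   // hidden activations
  auto output2 = output1.matmul(weight2.t()) + bias2; // final output
  return {output1, output2, gelu_in};
}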
template <typename T>
int linear_gelu_linear_backward_cuda(T* input, T* gelu_in, T* output1, T* weight1, T* weight2, T* d_output1, T* d_output2, int in_features, int batch_size, int hidden_features, int out_features, T* d_weight1, T* d_weight2, T* d_bias1, T* d_bias2, T* d_input, void* lt_workspace) {
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  // Get the stream from cublas handle to reuse for biasReLU kernel.
  cudaStream_t stream;
  cublasGetStream(handle, &stream);
  const float alpha = 1.0;
  const float beta_zero = 0.0;
  int status = 1;
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
  //wgrad for first gemm
  status = gemm_bgradb_lt(
      (cublasLtHandle_t)handle,
      CUBLAS_OP_N,
      CUBLAS_OP_T,
      hidden_features,
      out_features,
      batch_size,
      &alpha, /* host pointer */
      output1,
      hidden_features,
      d_output2,
      out_features,
      &beta_zero, /* host pointer */
      d_weight2,
      hidden_features,
      lt_workspace,
      1 << 22,
      stream,
      true,
      static_cast<const void*>(d_bias2));
  //dgrad for second GEMM
  status = gemm_dgelu_bgradb_lt(
      (cublasLtHandle_t)handle,
      CUBLAS_OP_N,
      CUBLAS_OP_N,
      hidden_features,
      batch_size,
      out_features,
      &alpha, /* host pointer */
      weight2,
      hidden_features,
      d_output2,
      out_features,
      &beta_zero, /* host pointer */
      d_output1,
      hidden_features,
      lt_workspace,
      1 << 22,
      stream,
      static_cast<const void*>(gelu_in),
      static_cast<const void*>(d_bias1));
  //wgrad for the first GEMM
  status = gemm_bias(
      handle, CUBLAS_OP_N, CUBLAS_OP_T, in_features, hidden_features, batch_size,
      &alpha, input, in_features, d_output1, hidden_features,
      &beta_zero, d_weight1, in_features);
  //dgrad for the first GEMM
  status = gemm_bias(
      handle, CUBLAS_OP_N, CUBLAS_OP_N, in_features, batch_size, hidden_features,
      &alpha, weight1, in_features, d_output1, hidden_features,
      &beta_zero, d_input, in_features);
#endif
  return status;
}
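The four GEMMs above map one-to-one onto the MLP backward: wgrad and bias-grad of the second layer (fused), dgrad through the GELU plus the first-layer bias-grad (fused), then plain wgrad and dgrad for the first layer. A reference sketch in ATen (hypothetical helper, for illustration only):

#include <ATen/ATen.h>
#include <vector>

// Hypothetical reference for the fused MLP backward.
std::vector<at::Tensor> mlp_backward_reference(
    at::Tensor input, at::Tensor gelu_in, at::Tensor output1,
    at::Tensor weight1, at::Tensor weight2, at::Tensor d_output2) {
  auto d_weight2 = d_output2.t().matmul(output1);                          // gemm_bgradb_lt GEMM
  auto d_bias2   = d_output2.sum(0);                                       // its BGRADB epilogue
  auto d_output1 = at::gelu_backward(d_output2.matmul(weight2), gelu_in);  // gemm_dgelu_bgradb_lt
  auto d_bias1   = d_output1.sum(0);                                       // its BGRAD epilogue
  auto d_weight1 = d_output1.t().matmul(input);                            // plain gemm_bias wgrad
  auto d_input   = d_output1.matmul(weight1);                              // plain gemm_bias dgrad
  return {d_input, d_weight1, d_bias1, d_weight2, d_bias2};
}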
template int linear_bias_forward_cuda<at::Half>(at::Tensor input, at::Half* weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void* lt_workspace);
template int linear_bias_forward_cuda<float>(at::Tensor input, float* weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void* lt_workspace);
template int linear_bias_forward_cuda<double>(at::Tensor input, double* weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void* lt_workspace);
template int linear_bias_backward_cuda<at::Half>(at::Half* input, at::Half* weight, at::Half* d_output, int in_features, int batch_size, int out_features, at::Half* d_weight, at::Half* d_bias, at::Half* d_input, void* lt_workspace);
template int linear_bias_backward_cuda<float>(float* input, float* weight, float* d_output, int in_features, int batch_size, int out_features, float* d_weight, float* d_bias, float* d_input, void* lt_workspace);
template int linear_bias_backward_cuda<double>(double* input, double* weight, double* d_output, int in_features, int batch_size, int out_features, double* d_weight, double* d_bias, double* d_input, void* lt_workspace);
template int linear_gelu_linear_forward_cuda<at::Half>(at::Half* input, at::Half* weight1, at::Half* bias1, at::Half* weight2, at::Half* bias2, int in_features, int hidden_features, int batch_size, int out_features, at::Half* output1, at::Half* output2, at::Half* gelu_in, void* lt_workspace);
template int linear_gelu_linear_forward_cuda<float>(float* input, float* weight1, float* bias1, float* weight2, float* bias2, int in_features, int hidden_features, int batch_size, int out_features, float* output1, float* output2, float* gelu_in, void* lt_workspace);
template int linear_gelu_linear_forward_cuda<double>(double* input, double* weight1, double* bias1, double* weight2, double* bias2, int in_features, int hidden_features, int batch_size, int out_features, double* output1, double* output2, double* gelu_in, void* lt_workspace);
template int linear_gelu_linear_backward_cuda<at::Half>(at::Half* input, at::Half* gelu_in, at::Half* output1, at::Half* weight1, at::Half* weight2, at::Half* d_output1, at::Half* d_output2, int in_features, int batch_size, int hidden_features, int out_features, at::Half* d_weight1, at::Half* d_weight2, at::Half* d_bias1, at::Half* d_bias2, at::Half* d_input, void* lt_workspace);
template int linear_gelu_linear_backward_cuda<float>(float* input, float* gelu_in, float* output1, float* weight1, float* weight2, float* d_output1, float* d_output2, int in_features, int batch_size, int hidden_features, int out_features, float* d_weight1, float* d_weight2, float* d_bias1, float* d_bias2, float* d_input, void* lt_workspace);
template int linear_gelu_linear_backward_cuda<double>(double* input, double* gelu_in, double* output1, double* weight1, double* weight2, double* d_output1, double* d_output2, int in_features, int batch_size, int hidden_features, int out_features, double* d_weight1, double* d_weight2, double* d_bias1, double* d_bias2, double* d_input, void* lt_workspace);
csrc/layer_norm_cuda.cpp
View file @
f79993d9
...
@@ -130,13 +130,13 @@ std::vector<at::Tensor> layer_norm(
   int n1, n2;
   check_args(input, normalized_shape, n1, n2);
   at::Tensor output = at::empty_like(input);
-  at::Tensor mean = at::empty({n1}, input.options().dtype((input.scalar_type() == at::ScalarType::Half ||
-    input.scalar_type() == at::ScalarType::BFloat16 ? at::ScalarType::Float : input.scalar_type())));
+  at::Tensor mean = at::empty({n1}, input.options().dtype((input.scalar_type() == at::ScalarType::Half ||
+    input.scalar_type() == at::ScalarType::BFloat16) ? at::ScalarType::Float : input.scalar_type()));
   at::Tensor invvar = at::empty_like(mean);
   cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2,
       normalized_shape, NULL, NULL, epsilon);
   return {output, mean, invvar};
 }

 std::vector<at::Tensor> layer_norm_affine(
     at::Tensor input,
 #ifdef VERSION_GE_1_1
...
@@ -153,14 +153,35 @@ std::vector<at::Tensor> layer_norm_affine(
   int n1, n2;
   check_args(input, normalized_shape, gamma, beta, n1, n2);
   at::Tensor output = at::empty_like(input);
-  at::Tensor mean = at::empty({n1}, input.options().dtype((input.scalar_type() == at::ScalarType::Half ||
-    input.scalar_type() == at::ScalarType::BFloat16) ? at::ScalarType::Float : input.scalar_type()));
+  const auto stats_dtype = (input.scalar_type() == at::ScalarType::Half ||
+    input.scalar_type() == at::ScalarType::BFloat16) ? at::ScalarType::Float : input.scalar_type();
+  at::Tensor mean = at::empty({n1}, input.options().dtype(stats_dtype));
   at::Tensor invvar = at::empty_like(mean);
   cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2,
       normalized_shape, &gamma, &beta, epsilon);
   return {output, mean, invvar};
 }

+std::vector<at::Tensor> layer_norm_affine_mixed_dtypes(
+    at::Tensor input,
+#ifdef VERSION_GE_1_1
+    at::IntArrayRef normalized_shape,
+#else
+    at::IntList normalized_shape,
+#endif
+    at::Tensor gamma,
+    at::Tensor beta,
+    double epsilon) {
+  CHECK_INPUT(input);
+  int n1, n2;
+  check_args(input, normalized_shape, n1, n2);
+  at::Tensor output = at::empty_like(input, gamma.options().dtype(gamma.scalar_type()));
+  at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type() == at::ScalarType::Half ||
+    input.scalar_type() == at::ScalarType::BFloat16 ? at::ScalarType::Float : input.scalar_type()));
+  at::Tensor invvar = at::empty_like(mean);
+  cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2,
+      normalized_shape, &gamma, &beta, epsilon);
+  return {output, mean, invvar};
+}

 void cuda_layer_norm_gradient(
     at::Tensor* dout,
     at::Tensor* mean,
...
@@ -204,6 +225,7 @@ at::Tensor layer_norm_gradient(
       &grad_input, NULL, NULL);
   return grad_input;
 }

 std::vector<at::Tensor> layer_norm_gradient_affine(
     at::Tensor dout,
     at::Tensor mean,
...
@@ -239,5 +261,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("forward", &layer_norm, "LayerNorm forward (CUDA)");
   m.def("backward_affine", &layer_norm_gradient_affine, "LayerNorm backward (CUDA)");
   m.def("backward", &layer_norm_gradient, "LayerNorm backward (CUDA)");
+  m.def("forward_affine_mixed_dtypes", &layer_norm_affine_mixed_dtypes,
+        "LayerNorm forward with mixed dtypes (CUDA) compatible with Megatron's implementation");
 }
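The net effect of this diff is that the layer-norm statistics (mean and invvar) are kept in float32 whenever the input is half or bfloat16, and a new entry point lets gamma/beta (and hence the output) use a different dtype than the input, matching Megatron's usage. The dtype rule it introduces is captured by this small sketch (hypothetical helper, assuming ATen):

#include <ATen/ATen.h>

// Hypothetical helper mirroring the stats_dtype rule above: half and bfloat16
// inputs keep their mean/invvar statistics in float.
at::ScalarType layer_norm_stats_dtype(const at::Tensor& input) {
  auto st = input.scalar_type();
  return (st == at::ScalarType::Half || st == at::ScalarType::BFloat16)
             ? at::ScalarType::Float
             : st;
}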
csrc/layer_norm_cuda_kernel.cu
View file @
f79993d9
#include "ATen/ATen.h"
#include "ATen/ATen.h"
#include "ATen/AccumulateType.h"
#include "ATen/AccumulateType.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/CUDAContext.h"
#include
<THC/THC
DeviceUtils.cuh
>
#include
"ATen/cuda/
DeviceUtils.cuh
"
#include <cuda.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_runtime.h>
...
@@ -56,7 +56,7 @@ void cuWelfordMuSigma2(
...
@@ -56,7 +56,7 @@ void cuWelfordMuSigma2(
const
int
i1
,
const
int
i1
,
U
&
mu
,
U
&
mu
,
U
&
sigma2
,
U
&
sigma2
,
U
*
buf
)
U
*
buf
)
{
{
// Assumptions:
// Assumptions:
// 1) blockDim.x == warpSize
// 1) blockDim.x == warpSize
...
@@ -140,7 +140,7 @@ void cuWelfordMuSigma2(
...
@@ -140,7 +140,7 @@ void cuWelfordMuSigma2(
const
int
i1
,
const
int
i1
,
float
&
mu
,
float
&
mu
,
float
&
sigma2
,
float
&
sigma2
,
float
*
buf
)
float
*
buf
)
{
{
// Assumptions:
// Assumptions:
// 1) blockDim.x == warpSize
// 1) blockDim.x == warpSize
...
@@ -172,8 +172,8 @@ void cuWelfordMuSigma2(
...
@@ -172,8 +172,8 @@ void cuWelfordMuSigma2(
for
(;
l
+
7
<
n2
;
l
+=
8
*
numx
)
{
for
(;
l
+
7
<
n2
;
l
+=
8
*
numx
)
{
for
(
int
k
=
0
;
k
<
8
;
k
+=
2
)
{
for
(
int
k
=
0
;
k
<
8
;
k
+=
2
)
{
float2
curr
=
__half22float2
(
*
((
__half2
*
)(
lvals
+
l
+
k
)));
float2
curr
=
__half22float2
(
*
((
__half2
*
)(
lvals
+
l
+
k
)));
cuWelfordOnlineSum
<
float
>
(
curr
.
x
,
mu
,
sigma2
,
count
);
cuWelfordOnlineSum
(
curr
.
x
,
mu
,
sigma2
,
count
);
cuWelfordOnlineSum
<
float
>
(
curr
.
y
,
mu
,
sigma2
,
count
);
cuWelfordOnlineSum
(
curr
.
y
,
mu
,
sigma2
,
count
);
}
}
}
}
for
(;
l
<
n2
;
++
l
)
{
for
(;
l
<
n2
;
++
l
)
{
...
@@ -282,18 +282,18 @@ struct SharedMemory <double>
...
@@ -282,18 +282,18 @@ struct SharedMemory <double>
};
};
}
}
template
<
typename
T
,
typename
U
>
__global
__
template
<
typename
T
,
typename
U
,
typename
V
>
__device
__
void
cuApplyLayerNorm
(
void
cuApplyLayerNorm
_
(
T
*
__restrict__
output_vals
,
V
*
__restrict__
output_vals
,
U
*
__restrict__
mean
,
U
*
__restrict__
mean
,
U
*
__restrict__
invvar
,
U
*
__restrict__
invvar
,
const
T
*
__restrict__
vals
,
const
T
*
__restrict__
vals
,
const
int
n1
,
const
int
n1
,
const
int
n2
,
const
int
n2
,
const
U
epsilon
,
const
U
epsilon
,
const
T
*
__restrict__
gamma
,
const
V
*
__restrict__
gamma
,
const
T
*
__restrict__
beta
const
V
*
__restrict__
beta
)
)
{
{
// Assumptions:
// Assumptions:
// 1) blockDim.x == warpSize
// 1) blockDim.x == warpSize
...
@@ -305,29 +305,47 @@ void cuApplyLayerNorm(
...
@@ -305,29 +305,47 @@ void cuApplyLayerNorm(
U
mu
,
sigma2
;
U
mu
,
sigma2
;
cuWelfordMuSigma2
(
vals
,
n1
,
n2
,
i1
,
mu
,
sigma2
,
buf
);
cuWelfordMuSigma2
(
vals
,
n1
,
n2
,
i1
,
mu
,
sigma2
,
buf
);
const
T
*
lvals
=
vals
+
i1
*
n2
;
const
T
*
lvals
=
vals
+
i1
*
n2
;
T
*
ovals
=
output_vals
+
i1
*
n2
;
V
*
ovals
=
output_vals
+
i1
*
n2
;
U
c_invvar
=
rsqrt
(
sigma2
+
epsilon
);
U
c_invvar
=
rsqrt
(
sigma2
+
epsilon
);
const
int
numx
=
blockDim
.
x
*
blockDim
.
y
;
const
int
numx
=
blockDim
.
x
*
blockDim
.
y
;
const
int
thrx
=
threadIdx
.
x
+
threadIdx
.
y
*
blockDim
.
x
;
const
int
thrx
=
threadIdx
.
x
+
threadIdx
.
y
*
blockDim
.
x
;
if
(
gamma
!=
NULL
&&
beta
!=
NULL
)
{
if
(
gamma
!=
NULL
&&
beta
!=
NULL
)
{
for
(
int
i
=
thrx
;
i
<
n2
;
i
+=
numx
)
{
for
(
int
i
=
thrx
;
i
<
n2
;
i
+=
numx
)
{
U
curr
=
static_cast
<
U
>
(
lvals
[
i
]);
U
curr
=
static_cast
<
U
>
(
lvals
[
i
]);
ovals
[
i
]
=
gamma
[
i
]
*
static_cast
<
T
>
(
c_invvar
*
(
curr
-
mu
))
+
beta
[
i
];
ovals
[
i
]
=
gamma
[
i
]
*
static_cast
<
V
>
(
c_invvar
*
(
curr
-
mu
))
+
beta
[
i
];
}
}
}
else
{
}
else
{
for
(
int
i
=
thrx
;
i
<
n2
;
i
+=
numx
)
{
for
(
int
i
=
thrx
;
i
<
n2
;
i
+=
numx
)
{
U
curr
=
static_cast
<
U
>
(
lvals
[
i
]);
U
curr
=
static_cast
<
U
>
(
lvals
[
i
]);
ovals
[
i
]
=
static_cast
<
T
>
(
c_invvar
*
(
curr
-
mu
));
ovals
[
i
]
=
static_cast
<
V
>
(
c_invvar
*
(
curr
-
mu
));
}
}
}
}
if
(
threadIdx
.
x
==
0
&&
threadIdx
.
y
==
0
)
{
if
(
threadIdx
.
x
==
0
&&
threadIdx
.
y
==
0
)
{
mean
[
i1
]
=
mu
;
mean
[
i1
]
=
mu
;
invvar
[
i1
]
=
c_invvar
;
invvar
[
i1
]
=
c_invvar
;
}
}
__syncthreads
();
}
}
}
}
template
<
typename
T
,
typename
U
>
__device__
template
<
typename
T
,
typename
U
,
typename
V
=
T
>
__global__
void
cuApplyLayerNorm
(
V
*
__restrict__
output_vals
,
U
*
__restrict__
mean
,
U
*
__restrict__
invvar
,
const
T
*
__restrict__
vals
,
const
int
n1
,
const
int
n2
,
const
U
epsilon
,
const
V
*
__restrict__
gamma
,
const
V
*
__restrict__
beta
)
{
cuApplyLayerNorm_
<
T
,
U
,
V
>
(
output_vals
,
mean
,
invvar
,
vals
,
n1
,
n2
,
epsilon
,
gamma
,
beta
);
}
template
<
typename
T
,
typename
U
,
typename
V
>
__device__
void
cuLoadWriteStridedInputs
(
void
cuLoadWriteStridedInputs
(
const
int
i1_block
,
const
int
i1_block
,
const
int
thr_load_row_off
,
const
int
thr_load_row_off
,
...
@@ -337,7 +355,7 @@ void cuLoadWriteStridedInputs(
...
@@ -337,7 +355,7 @@ void cuLoadWriteStridedInputs(
U
*
warp_buf1
,
U
*
warp_buf1
,
U
*
warp_buf2
,
U
*
warp_buf2
,
const
T
*
input
,
const
T
*
input
,
const
T
*
dout
,
const
V
*
dout
,
const
int
i1_end
,
const
int
i1_end
,
const
int
n2
,
const
int
n2
,
const
U
*
__restrict__
mean
,
const
U
*
__restrict__
mean
,
...
@@ -354,9 +372,9 @@ void cuLoadWriteStridedInputs(
...
@@ -354,9 +372,9 @@ void cuLoadWriteStridedInputs(
int
write_idx
=
thr_load_row_off
*
row_stride
+
thr_load_col_off
+
k
;
int
write_idx
=
thr_load_row_off
*
row_stride
+
thr_load_col_off
+
k
;
if
(
i2
<
n2
)
{
if
(
i2
<
n2
)
{
U
curr_input
=
static_cast
<
U
>
(
input
[
load_idx
]);
U
curr_input
=
static_cast
<
U
>
(
input
[
load_idx
]);
U
curr_dout
=
static_cast
<
U
>
(
dout
[
load_idx
]);
U
curr_dout
=
static_cast
<
U
>
(
dout
[
load_idx
]);
warp_buf1
[
write_idx
]
=
curr_dout
;
warp_buf1
[
write_idx
]
=
curr_dout
;
warp_buf2
[
write_idx
]
=
curr_dout
*
(
curr_input
-
curr_mean
)
*
curr_invvar
;
warp_buf2
[
write_idx
]
=
curr_dout
*
(
curr_input
-
curr_mean
)
*
curr_invvar
;
}
else
{
}
else
{
warp_buf1
[
write_idx
]
=
U
(
0
);
warp_buf1
[
write_idx
]
=
U
(
0
);
warp_buf2
[
write_idx
]
=
U
(
0
);
warp_buf2
[
write_idx
]
=
U
(
0
);
...
@@ -371,7 +389,7 @@ void cuLoadWriteStridedInputs(
...
@@ -371,7 +389,7 @@ void cuLoadWriteStridedInputs(
}
}
}
}
template
<
typename
T
,
typename
U
>
__device__
template
<
typename
T
,
typename
U
,
typename
V
>
__device__
void
cuLoadAddStridedInputs
(
void
cuLoadAddStridedInputs
(
const
int
i1_block
,
const
int
i1_block
,
const
int
thr_load_row_off
,
const
int
thr_load_row_off
,
...
@@ -381,7 +399,7 @@ void cuLoadAddStridedInputs(
...
@@ -381,7 +399,7 @@ void cuLoadAddStridedInputs(
U
*
warp_buf1
,
U
*
warp_buf1
,
U
*
warp_buf2
,
U
*
warp_buf2
,
const
T
*
input
,
const
T
*
input
,
const
T
*
dout
,
const
V
*
dout
,
const
int
i1_end
,
const
int
i1_end
,
const
int
n2
,
const
int
n2
,
const
U
*
__restrict__
mean
,
const
U
*
__restrict__
mean
,
...
@@ -398,17 +416,17 @@ void cuLoadAddStridedInputs(
...
@@ -398,17 +416,17 @@ void cuLoadAddStridedInputs(
int
write_idx
=
thr_load_row_off
*
row_stride
+
thr_load_col_off
+
k
;
int
write_idx
=
thr_load_row_off
*
row_stride
+
thr_load_col_off
+
k
;
if
(
i2
<
n2
)
{
if
(
i2
<
n2
)
{
U
curr_input
=
static_cast
<
U
>
(
input
[
load_idx
]);
U
curr_input
=
static_cast
<
U
>
(
input
[
load_idx
]);
U
curr_dout
=
static_cast
<
U
>
(
dout
[
load_idx
]);
U
curr_dout
=
static_cast
<
U
>
(
dout
[
load_idx
]);
warp_buf1
[
write_idx
]
+=
curr_dout
;
warp_buf1
[
write_idx
]
+=
curr_dout
;
warp_buf2
[
write_idx
]
+=
curr_dout
*
(
curr_input
-
curr_mean
)
*
curr_invvar
;
warp_buf2
[
write_idx
]
+=
curr_dout
*
(
curr_input
-
curr_mean
)
*
curr_invvar
;
}
}
}
}
}
}
}
}
template
<
typename
T
,
typename
U
>
__global__
template
<
typename
T
,
typename
U
,
typename
V
>
__global__
void
cuComputePartGradGammaBeta
(
void
cuComputePartGradGammaBeta
(
const
T
*
__restrict__
dout
,
const
V
*
__restrict__
dout
,
const
T
*
__restrict__
input
,
const
T
*
__restrict__
input
,
const
int
n1
,
const
int
n1
,
const
int
n2
,
const
int
n2
,
...
@@ -455,11 +473,11 @@ void cuComputePartGradGammaBeta(
...
@@ -455,11 +473,11 @@ void cuComputePartGradGammaBeta(
for
(
int
offset
=
blockDim
.
y
/
2
;
offset
>
1
;
offset
/=
2
)
{
for
(
int
offset
=
blockDim
.
y
/
2
;
offset
>
1
;
offset
/=
2
)
{
if
(
threadIdx
.
y
<
offset
)
{
if
(
threadIdx
.
y
<
offset
)
{
int
row1
=
threadIdx
.
y
;
int
row1
=
threadIdx
.
y
;
int
row2
=
threadIdx
.
y
+
offset
;
int
row2
=
threadIdx
.
y
+
offset
;
int
idx1
=
row1
*
row_stride
+
threadIdx
.
x
;
int
idx1
=
row1
*
row_stride
+
threadIdx
.
x
;
int
idx2
=
row2
*
row_stride
+
threadIdx
.
x
;
int
idx2
=
row2
*
row_stride
+
threadIdx
.
x
;
warp_buf1
[
idx1
]
+=
warp_buf1
[
idx2
];
warp_buf1
[
idx1
]
+=
warp_buf1
[
idx2
];
warp_buf2
[
idx1
]
+=
warp_buf2
[
idx2
];
warp_buf2
[
idx1
]
+=
warp_buf2
[
idx2
];
}
}
__syncthreads
();
__syncthreads
();
}
}
...
@@ -474,19 +492,19 @@ void cuComputePartGradGammaBeta(
...
@@ -474,19 +492,19 @@ void cuComputePartGradGammaBeta(
}
}
}
}
template
<
typename
T
,
typename
U
>
__global__
template
<
typename
U
,
typename
V
>
__global__
void
cuComputeGradGammaBeta
(
void
cuComputeGradGammaBeta
(
const
U
*
part_grad_gamma
,
const
U
*
part_grad_gamma
,
const
U
*
part_grad_beta
,
const
U
*
part_grad_beta
,
const
int
part_size
,
const
int
part_size
,
const
int
n1
,
const
int
n1
,
const
int
n2
,
const
int
n2
,
T
*
grad_gamma
,
V
*
grad_gamma
,
T
*
grad_beta
)
V
*
grad_beta
)
{
{
// sum partial gradients for gamma and beta
// sum partial gradients for gamma and beta
SharedMemory
<
U
>
shared
;
SharedMemory
<
U
>
shared
;
U
*
buf
=
shared
.
getPointer
();
U
*
buf
=
shared
.
getPointer
();
int
i2
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
i2
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
i2
<
n2
)
{
if
(
i2
<
n2
)
{
// each warp does sequential reductions until reduced part_size is num_warps
// each warp does sequential reductions until reduced part_size is num_warps
...
@@ -525,16 +543,16 @@ void cuComputeGradGammaBeta(
...
@@ -525,16 +543,16 @@ void cuComputeGradGammaBeta(
}
}
}
}
template
<
typename
T
,
typename
U
>
__global__
template
<
typename
T
,
typename
U
,
typename
V
>
__global__
void
cuComputeGradInput
(
void
cuComputeGradInput
(
const
T
*
__restrict__
dout
,
const
V
*
__restrict__
dout
,
const
T
*
__restrict__
input
,
const
T
*
__restrict__
input
,
const
int
n1
,
const
int
n1
,
const
int
n2
,
const
int
n2
,
const
U
*
__restrict__
mean
,
const
U
*
__restrict__
mean
,
const
U
*
__restrict__
invvar
,
const
U
*
__restrict__
invvar
,
U
epsilon
,
U
epsilon
,
const
T
*
gamma
,
const
V
*
gamma
,
T
*
grad_input
)
T
*
grad_input
)
{
{
for
(
int
i1
=
blockIdx
.
y
;
i1
<
n1
;
i1
+=
gridDim
.
y
)
{
for
(
int
i1
=
blockIdx
.
y
;
i1
<
n1
;
i1
+=
gridDim
.
y
)
{
...
@@ -543,7 +561,7 @@ void cuComputeGradInput(
...
@@ -543,7 +561,7 @@ void cuComputeGradInput(
const
U
c_mean
=
mean
[
i1
];
const
U
c_mean
=
mean
[
i1
];
const
U
c_invvar
=
invvar
[
i1
];
const
U
c_invvar
=
invvar
[
i1
];
const
T
*
k_input
=
input
+
i1
*
n2
;
const
T
*
k_input
=
input
+
i1
*
n2
;
const
T
*
k_dout
=
dout
+
i1
*
n2
;
const
V
*
k_dout
=
dout
+
i1
*
n2
;
const
int
numx
=
blockDim
.
x
*
blockDim
.
y
;
const
int
numx
=
blockDim
.
x
*
blockDim
.
y
;
const
int
thrx
=
threadIdx
.
x
+
threadIdx
.
y
*
blockDim
.
x
;
const
int
thrx
=
threadIdx
.
x
+
threadIdx
.
y
*
blockDim
.
x
;
if
(
gamma
!=
NULL
)
{
if
(
gamma
!=
NULL
)
{
...
@@ -587,7 +605,7 @@ void cuComputeGradInput(
...
@@ -587,7 +605,7 @@ void cuComputeGradInput(
// inter-warp reductions
// inter-warp reductions
if
(
blockDim
.
y
>
1
)
{
if
(
blockDim
.
y
>
1
)
{
SharedMemory
<
U
>
shared
;
SharedMemory
<
U
>
shared
;
U
*
buf
=
shared
.
getPointer
();
U
*
buf
=
shared
.
getPointer
();
for
(
int
offset
=
blockDim
.
y
/
2
;
offset
>
0
;
offset
/=
2
)
{
for
(
int
offset
=
blockDim
.
y
/
2
;
offset
>
0
;
offset
/=
2
)
{
// upper half of warps write to shared
// upper half of warps write to shared
if
(
threadIdx
.
y
>=
offset
&&
threadIdx
.
y
<
2
*
offset
)
{
if
(
threadIdx
.
y
>=
offset
&&
threadIdx
.
y
<
2
*
offset
)
{
...
@@ -612,7 +630,7 @@ void cuComputeGradInput(
...
@@ -612,7 +630,7 @@ void cuComputeGradInput(
if
(
threadIdx
.
y
!=
0
)
{
if
(
threadIdx
.
y
!=
0
)
{
sum_loss1
=
buf
[
2
*
threadIdx
.
x
];
sum_loss1
=
buf
[
2
*
threadIdx
.
x
];
sum_loss2
=
buf
[
2
*
threadIdx
.
x
+
1
];
sum_loss2
=
buf
[
2
*
threadIdx
.
x
+
1
];
}
}
}
}
// all threads now have the two sums over l
// all threads now have the two sums over l
U
fH
=
(
U
)
n2
;
U
fH
=
(
U
)
n2
;
...
@@ -639,38 +657,34 @@ void cuComputeGradInput(
...
@@ -639,38 +657,34 @@ void cuComputeGradInput(
k_grad_input
[
l
]
=
static_cast
<
T
>
(
f_grad_input
);
k_grad_input
[
l
]
=
static_cast
<
T
>
(
f_grad_input
);
}
}
}
}
// prevent race where buf is written again before reads are done
__syncthreads
();
}
}
}
}
template
<
typename
T
,
typename
U
>
template
<
typename
T
,
typename
U
,
typename
V
=
T
>
void
HostApplyLayerNorm
(
void
HostApplyLayerNorm
(
T
*
output
,
V
*
output
,
U
*
mean
,
U
*
mean
,
U
*
invvar
,
U
*
invvar
,
const
T
*
input
,
const
T
*
input
,
int
n1
,
int
n1
,
int
n2
,
int
n2
,
double
epsilon
,
double
epsilon
,
const
T
*
gamma
,
const
V
*
gamma
,
const
T
*
beta
const
V
*
beta
)
)
{
{
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
const
dim3
threads
(
32
,
4
,
1
);
const
dim3
threads
(
32
,
4
,
1
);
const
uint64_t
maxGridY
=
at
::
cuda
::
getCurrentDeviceProperties
()
->
maxGridSize
[
1
];
const
uint64_t
maxGridY
=
at
::
cuda
::
getCurrentDeviceProperties
()
->
maxGridSize
[
1
];
const
dim3
blocks
(
1
,
std
::
min
((
uint64_t
)
n1
,
maxGridY
),
1
);
const
dim3
blocks
(
1
,
std
::
min
((
uint64_t
)
n1
,
maxGridY
),
1
);
int
nshared
=
int
nshared
=
threads
.
y
>
1
?
threads
.
y
>
1
?
threads
.
y
*
sizeof
(
U
)
+
(
threads
.
y
/
2
)
*
sizeof
(
U
)
:
threads
.
y
*
sizeof
(
U
)
+
(
threads
.
y
/
2
)
*
sizeof
(
U
)
:
0
;
0
;
cuApplyLayerNorm
<<<
blocks
,
threads
,
nshared
,
stream
>>>
(
cuApplyLayerNorm
<<<
blocks
,
threads
,
nshared
,
stream
>>>
(
output
,
output
,
mean
,
invvar
,
input
,
n1
,
n2
,
U
(
epsilon
),
gamma
,
beta
);
mean
,
invvar
,
input
,
n1
,
n2
,
U
(
epsilon
),
gamma
,
beta
);
}
}
void
cuda_layer_norm
(
void
cuda_layer_norm
(
...
@@ -690,34 +704,35 @@ void cuda_layer_norm(
...
@@ -690,34 +704,35 @@ void cuda_layer_norm(
double
epsilon
)
double
epsilon
)
{
{
using
namespace
at
;
using
namespace
at
;
DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BFLOAT16
(
input
->
scalar_type
(),
0
,
"layer_norm_cuda_kernel"
,
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES
(
using
accscalar_t
=
at
::
acc_type
<
scalar_t_0
,
true
>
;
input
->
scalar_type
(),
output
->
scalar_type
(),
"layer_norm_cuda_kernel"
,
HostApplyLayerNorm
(
using
accscalar_t
=
at
::
acc_type
<
scalar_t_in
,
true
>
;
output
->
DATA_PTR
<
scalar_t_0
>
(),
HostApplyLayerNorm
<
scalar_t_in
,
accscalar_t
,
scalar_t_out
>
(
mean
->
DATA_PTR
<
accscalar_t
>
(),
output
->
DATA_PTR
<
scalar_t_out
>
(),
invvar
->
DATA_PTR
<
accscalar_t
>
(),
mean
->
DATA_PTR
<
accscalar_t
>
(),
input
->
DATA_PTR
<
scalar_t_0
>
(),
invvar
->
DATA_PTR
<
accscalar_t
>
(),
n1
,
n2
,
input
->
DATA_PTR
<
scalar_t_in
>
(),
epsilon
,
n1
,
n2
,
gamma
!=
NULL
?
gamma
->
DATA_PTR
<
scalar_t_0
>
()
:
NULL
,
epsilon
,
beta
!=
NULL
?
beta
->
DATA_PTR
<
scalar_t_0
>
()
:
NULL
);
gamma
!=
NULL
?
gamma
->
DATA_PTR
<
scalar_t_out
>
()
:
NULL
,
beta
!=
NULL
?
beta
->
DATA_PTR
<
scalar_t_out
>
()
:
NULL
);
)
)
}
}
template
<
typename
T
,
typename
U
>
template
<
typename
T
,
typename
U
=
float
,
typename
V
=
T
>
void
HostLayerNormGradient
(
void
HostLayerNormGradient
(
const
T
*
dout
,
const
V
*
dout
,
const
U
*
mean
,
const
U
*
mean
,
const
U
*
invvar
,
const
U
*
invvar
,
at
::
Tensor
*
input
,
at
::
Tensor
*
input
,
int
n1
,
int
n1
,
int
n2
,
int
n2
,
const
T
*
gamma
,
const
V
*
gamma
,
const
T
*
beta
,
const
V
*
beta
,
double
epsilon
,
double
epsilon
,
T
*
grad_input
,
T
*
grad_input
,
T
*
grad_gamma
,
V
*
grad_gamma
,
T
*
grad_beta
V
*
grad_beta
)
)
{
{
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
...
@@ -730,8 +745,13 @@ void HostLayerNormGradient(
...
@@ -730,8 +745,13 @@ void HostLayerNormGradient(
const
int
nshared2_a
=
2
*
sizeof
(
U
)
*
threads2
.
y
*
threads2
.
y
*
(
threads2
.
x
+
1
);
const
int
nshared2_a
=
2
*
sizeof
(
U
)
*
threads2
.
y
*
threads2
.
y
*
(
threads2
.
x
+
1
);
const
int
nshared2_b
=
threads2
.
x
*
threads2
.
y
*
sizeof
(
U
);
const
int
nshared2_b
=
threads2
.
x
*
threads2
.
y
*
sizeof
(
U
);
const
int
nshared2
=
nshared2_a
>
nshared2_b
?
nshared2_a
:
nshared2_b
;
const
int
nshared2
=
nshared2_a
>
nshared2_b
?
nshared2_a
:
nshared2_b
;
at
::
Tensor
part_grad_gamma
=
at
::
empty
({
part_size
,
n2
},
input
->
options
().
dtype
((
input
->
scalar_type
()
==
at
::
ScalarType
::
Half
||
// note (mkozuki): I can hard code part_grad_gamma's dtype as float given that
input
->
scalar_type
()
==
at
::
ScalarType
::
BFloat16
)
?
at
::
ScalarType
::
Float
:
input
->
scalar_type
()));
// the `cuda_layer_norm_gradient` doesn't support double.
const
auto
part_grad_dtype
=
(
input
->
scalar_type
()
==
at
::
ScalarType
::
Half
||
input
->
scalar_type
()
==
at
::
ScalarType
::
BFloat16
)
?
at
::
ScalarType
::
Float
:
input
->
scalar_type
();
at
::
Tensor
part_grad_gamma
=
at
::
empty
({
part_size
,
n2
},
input
->
options
().
dtype
(
part_grad_dtype
));
at
::
Tensor
part_grad_beta
=
at
::
empty_like
(
part_grad_gamma
);
at
::
Tensor
part_grad_beta
=
at
::
empty_like
(
part_grad_gamma
);
cuComputePartGradGammaBeta
<<<
blocks2
,
threads2
,
nshared2
,
stream
>>>
(
cuComputePartGradGammaBeta
<<<
blocks2
,
threads2
,
nshared2
,
stream
>>>
(
dout
,
dout
,
...
@@ -794,21 +814,23 @@ void cuda_layer_norm_gradient(
...
@@ -794,21 +814,23 @@ void cuda_layer_norm_gradient(
at
::
Tensor
*
grad_beta
)
at
::
Tensor
*
grad_beta
)
{
{
using
namespace
at
;
using
namespace
at
;
DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16
(
input
->
scalar_type
(),
0
,
"cuComputeGradInput"
,
// we can do away with `accscalar_t` as there're only three dtypes: fp32, fp16, bf16
using
accscalar_t
=
at
::
acc_type
<
scalar_t_0
,
true
>
;
DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES
(
HostLayerNormGradient
(
input
->
scalar_type
(),
gamma
==
NULL
?
input
->
scalar_type
()
:
gamma
->
scalar_type
(),
"cuComputeGradInput"
,
dout
->
DATA_PTR
<
scalar_t_0
>
(),
using
accscalar_t
=
at
::
acc_type
<
scalar_t_in
,
true
>
;
mean
->
DATA_PTR
<
accscalar_t
>
(),
HostLayerNormGradient
(
invvar
->
DATA_PTR
<
accscalar_t
>
(),
dout
->
DATA_PTR
<
scalar_t_out
>
(),
input
,
mean
->
DATA_PTR
<
accscalar_t
>
(),
n1
,
n2
,
invvar
->
DATA_PTR
<
accscalar_t
>
(),
input
,
n1
,
n2
,
// TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
// TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
// if gamma Tensor is NULL on input.
// if gamma Tensor is NULL on input.
gamma
!=
NULL
?
gamma
->
DATA_PTR
<
scalar_t_
0
>
()
:
NULL
,
gamma
!=
NULL
?
gamma
->
DATA_PTR
<
scalar_t_
out
>
()
:
NULL
,
gamma
!=
NULL
?
beta
->
DATA_PTR
<
scalar_t_
0
>
()
:
NULL
,
gamma
!=
NULL
?
beta
->
DATA_PTR
<
scalar_t_
out
>
()
:
NULL
,
epsilon
,
epsilon
,
grad_input
->
DATA_PTR
<
scalar_t_
0
>
(),
grad_input
->
DATA_PTR
<
scalar_t_
in
>
(),
gamma
!=
NULL
?
grad_gamma
->
DATA_PTR
<
scalar_t_
0
>
()
:
NULL
,
gamma
!=
NULL
?
grad_gamma
->
DATA_PTR
<
scalar_t_
out
>
()
:
NULL
,
gamma
!=
NULL
?
grad_beta
->
DATA_PTR
<
scalar_t_
0
>
()
:
NULL
);
gamma
!=
NULL
?
grad_beta
->
DATA_PTR
<
scalar_t_
out
>
()
:
NULL
);
)
)
}
}
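The net effect of the hunks above is that the layer norm kernels gain a third template parameter V: activations are read as T, the Welford statistics are accumulated as U, and the output together with the gamma/beta affine transform is handled as V (with V defaulting to T, so existing instantiations keep their old behaviour). A host-side restatement of the per-row arithmetic may make the role of the three types easier to see; this is a sketch for illustration only, with an invented function name and a straightforward two-pass mean/variance in place of the kernel's Welford update:

#include <cmath>
#include <cstddef>

// Reference of what cuApplyLayerNorm_<T, U, V> computes for one row of length n2:
// statistics accumulated in U, output and affine parameters handled as V.
template <typename T, typename U, typename V = T>
void reference_layer_norm_row(const T* vals, V* out, const V* gamma, const V* beta,
                              std::size_t n2, U epsilon) {
  U mu = 0, sigma2 = 0;
  for (std::size_t i = 0; i < n2; ++i) mu += static_cast<U>(vals[i]);
  mu /= static_cast<U>(n2);
  for (std::size_t i = 0; i < n2; ++i) {
    U d = static_cast<U>(vals[i]) - mu;
    sigma2 += d * d;
  }
  sigma2 /= static_cast<U>(n2);                     // biased variance, as in layer norm
  U c_invvar = U(1) / std::sqrt(sigma2 + epsilon);  // rsqrt(sigma2 + epsilon) in the kernel
  for (std::size_t i = 0; i < n2; ++i) {
    U curr = static_cast<U>(vals[i]);
    V normed = static_cast<V>(c_invvar * (curr - mu));
    out[i] = (gamma != nullptr && beta != nullptr) ? V(gamma[i] * normed + beta[i]) : normed;
  }
}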
csrc/megatron/scaled_masked_softmax.cpp
new file mode 100644
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {

torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    torch::Tensor const& mask,
    float scale_factor);

torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads,
    torch::Tensor const& softmax_results,
    float scale_factor);

int get_batch_per_block_cuda(
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads);

torch::Tensor fwd(
    torch::Tensor const& input,
    torch::Tensor const& mask,
    float scale_factor) {
  AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
      (input.scalar_type() == at::ScalarType::BFloat16),
      "Only fp16 and bf16 are supported");
  AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");

  return fwd_cuda(input, mask, scale_factor);
}

torch::Tensor bwd(
    torch::Tensor const& output_grads,
    torch::Tensor const& softmax_results,
    float scale_factor) {
  AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor");
  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
      (output_grads.scalar_type() == at::ScalarType::BFloat16),
      "Only fp16 and bf16 are supported");
  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
      (softmax_results.scalar_type() == at::ScalarType::BFloat16),
      "Only fp16 and bf16 are supported");

  return bwd_cuda(output_grads, softmax_results, scale_factor);
}

int get_batch_per_block(
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads) {
  return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads);
}

} // end namespace scaled_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward",
        &multihead_attn::fused_softmax::scaled_masked_softmax::fwd,
        "Self Multihead Attention scaled, time masked softmax -- Forward.");
  m.def("backward",
        &multihead_attn::fused_softmax::scaled_masked_softmax::bwd,
        "Self Multihead Attention scaled, time masked softmax -- Backward.");
  m.def("get_batch_per_block",
        &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block,
        "Return Batch per block size.");
}
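The checks in fwd() and bwd() above only pin down rank and dtype; the full shape contract is enforced later in fwd_cuda (see scaled_masked_softmax_cuda.cu below). As a hedged illustration, the following sketch builds tensors that would pass both layers of checks; the sizes and variable names are made up, and the Python module name mentioned in the comment depends on how the extension is registered in setup.py:

#include <torch/torch.h>

int main() {
  // input: [batches, attn_heads, query_seq_len, key_seq_len], fp16 or bf16, on CUDA.
  // mask:  [1 or batches, 1, query_seq_len, key_seq_len], uint8, where 1 means "masked out".
  const int64_t batches = 2, attn_heads = 8, query_seq_len = 128, key_seq_len = 128;
  torch::Tensor input = torch::randn({batches, attn_heads, query_seq_len, key_seq_len},
                                     torch::dtype(torch::kHalf).device(torch::kCUDA));
  torch::Tensor mask = torch::zeros({1, 1, query_seq_len, key_seq_len},
                                    torch::dtype(torch::kUInt8).device(torch::kCUDA));
  // From Python this pair would typically be passed to the bound forward(), e.g.
  //   probs = scaled_masked_softmax_cuda.forward(input, mask, scale)
  return 0;
}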
csrc/megatron/scaled_masked_softmax.h
new file mode 100644
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <assert.h>
#include <cuda_fp16.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <cuda_fp16.h>
#include <c10/macros/Macros.h>
namespace {

template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }

template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }

template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) { *((half2*) dst) = *((half2*) src); }

int log2_ceil(int value) {
    int log2_value = 0;
    while ((1 << log2_value) < value) ++log2_value;
    return log2_value;
}

template<typename T>
struct Add {
  __device__ __forceinline__ T operator()(T a, T b) const {
    return a + b;
  }
};

template<typename T>
struct Max {
  __device__ __forceinline__ T operator()(T a, T b) const {
    return a < b ? b : a;
  }
};

template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if CUDA_VERSION >= 9000
    return __shfl_xor_sync(mask, value, laneMask, width);
#else
    return __shfl_xor(value, laneMask, width);
#endif
}

template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t* sum) {
    ReduceOp<acc_t> r;
    #pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        #pragma unroll
        for (int i = 0;  i < WARP_BATCH;  ++i) {
            acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
            sum[i] = r(sum[i], b);
        }
    }
}

/*
 * Extended softmax (from native aten pytorch) with following additional features
 * 1) input scaling
 * 2) Explicit masking
 */
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_masked_softmax_warp_forward(
    output_t *dst,
    const input_t *src,
    const uint8_t *mask,
    const acc_t scale,
    int micro_batch_size,
    int element_count,
    int pad_batches)
{
    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
    // warp_size of method warp_softmax_forward_kernel.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

    // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
    // gridDim/blockIdx = (seq_len, attn_heads, batches)
    int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) + threadIdx.y) * WARP_BATCH;
    int pad_first_batch = 0;
    if (pad_batches != 1) { // bert style
        pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH;
    } else { // gpt2 style
        pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
    }

    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to computed within this WARP.
    int local_batches = micro_batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x;

    src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
    dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
    mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;

    // load data from global memory
    acc_t elements[WARP_BATCH][WARP_ITERATIONS];
    input_t temp_data[ELEMENTS_PER_LDG_STG];
    uint8_t temp_mask[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : element_count;

        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;

            if (element_index < batch_element_count) {
                int itr_idx = i*element_count+it*WARP_SIZE;
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
                copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(temp_mask, mask + itr_idx);

                #pragma unroll
                for (int element = 0;  element < ELEMENTS_PER_LDG_STG;  ++element) {
                    if (temp_mask[element] != 1) {
                        elements[i][it + element] = (acc_t)temp_data[element] * scale;
                    } else {
                        elements[i][it + element] = -10000.0;
                    }
                }
            } else {
                #pragma unroll
                for (int element = 0;  element < ELEMENTS_PER_LDG_STG;  ++element) {
                    elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
                }
            }
        }
    }

    // compute max_value
    acc_t max_value[WARP_BATCH];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        max_value[i] = elements[i][0];
        #pragma unroll
        for (int it = 1;  it < WARP_ITERATIONS;  ++it) {
            max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);

    acc_t sum[WARP_BATCH] { 0.0f };
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
            elements[i][it] = std::exp((elements[i][it] - max_value[i]));
            sum[i] += elements[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    output_t out[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                #pragma unroll
                for (int element = 0;  element < ELEMENTS_PER_LDG_STG;  ++element) {
                    out[element] = elements[i][it + element] / sum[i];
                }
                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
            } else {
                break;
            }
        }
    }
}

template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_masked_softmax_warp_backward(
    output_t *gradInput,
    input_t *grad,
    const input_t *output,
    acc_t scale,
    int micro_batch_size,
    int element_count)
{
    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
    // warp_size of method warp_softmax_backward_kernel.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

    // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
    // gridDim/blockIdx = (seq_len, attn_heads, batches)
    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;

    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to computed within this WARP.
    int local_batches = micro_batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x;

    // the first element to process by the current thread
    int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
    grad += thread_offset;
    output += thread_offset;
    gradInput += thread_offset;

    // load data from global memory
    acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
    input_t temp_grad[ELEMENTS_PER_LDG_STG];
    input_t temp_output[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : element_count;

        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < batch_element_count) {
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count + it * WARP_SIZE);
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count + it * WARP_SIZE);

                #pragma unroll
                for (int element = 0;  element < ELEMENTS_PER_LDG_STG;  ++element) {
                    output_reg[i][it + element] = (acc_t)temp_output[element];
                }
                #pragma unroll
                for (int element = 0;  element < ELEMENTS_PER_LDG_STG;  ++element) {
                    grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
                }
            }
        }
    }

    acc_t sum[WARP_BATCH];
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        sum[i] = grad_reg[i][0];
        #pragma unroll
        for (int it = 1;  it < WARP_ITERATIONS;  ++it) {
            sum[i] += grad_reg[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    #pragma unroll
    for (int i = 0;  i < WARP_BATCH;  ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0;  it < WARP_ITERATIONS;  it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                // compute gradients
                output_t out[ELEMENTS_PER_LDG_STG];
                #pragma unroll
                for (int element = 0;  element < ELEMENTS_PER_LDG_STG;  ++element) {
                    out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
                }
                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count + it * WARP_SIZE, out);
            }
        }
    }
}

} // end of anonymous namespace

int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){
    int log2_elements = log2_ceil(key_seq_len);
    const int next_power_of_two = 1 << log2_elements;

    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

    constexpr int threads_per_block = 128;
    int warps_per_block = (threads_per_block / warp_size);
    int batches_per_block = warps_per_block * batches_per_warp;

    return batches_per_block;
}

template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_forward(
    output_t *dst,
    const input_t *src,
    const uint8_t *mask,
    const input_t scale,
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads,
    int pad_batches)
{
    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048);
    if (key_seq_len == 0) {
        return;
    } else {
        int log2_elements = log2_ceil(key_seq_len);
        const int next_power_of_two = 1 << log2_elements;
        int batch_count = batches * attn_heads * query_seq_len;

        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

        // use 128 threads per block to maximimize gpu utilization
        constexpr int threads_per_block = 128;

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        TORCH_INTERNAL_ASSERT(query_seq_len % batches_per_block == 0);
        dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches);
        dim3 threads(warp_size, warps_per_block, 1);
        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
        switch (log2_elements) {
            case 0:  /* 1 */    scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 0><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break;
            case 1:  /* 2 */    scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 1><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break;
            case 2:  /* 4 */    scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 2><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break;
            case 3:  /* 8 */    scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 3><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break;
            case 4:  /* 16 */   scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 4><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break;
            case 5:  /* 32 */   scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 5><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break;
            case 6:  /* 64 */   scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 6><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break;
            case 7:  /* 128 */  scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 7><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break;
            case 8:  /* 256 */  scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 8><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break;
            case 9:  /* 512 */  scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 9><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break;
            case 10: /* 1024 */ scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 10><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break;
            case 11: /* 2048 */ scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 11><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break;
            default:
                break;
        }
    }
}

template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_backward(
    output_t *grad_input,
    input_t *grad,
    const input_t *output,
    const acc_t scale,
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads)
{
    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048);
    if (key_seq_len == 0) {
        return;
    } else {
        int log2_elements = log2_ceil(key_seq_len);
        const int next_power_of_two = 1 << log2_elements;
        int batch_count = batches * attn_heads * query_seq_len;

        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

        // use 128 threads per block to maximimize gpu utilization
        constexpr int threads_per_block = 128;

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        int blocks = batch_count/batches_per_block;
        dim3 threads(warp_size, warps_per_block, 1);
        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
        switch (log2_elements) {
            case 0:  /* 1 */    scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 0><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len); break;
            case 1:  /* 2 */    scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 1><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len); break;
            case 2:  /* 4 */    scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 2><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len); break;
            case 3:  /* 8 */    scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 3><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len); break;
            case 4:  /* 16 */   scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 4><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len); break;
            case 5:  /* 32 */   scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 5><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len); break;
            case 6:  /* 64 */   scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 6><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len); break;
            case 7:  /* 128 */  scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 7><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len); break;
            case 8:  /* 256 */  scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 8><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len); break;
            case 9:  /* 512 */  scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 9><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len); break;
            case 10: /* 1024 */ scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 10><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len); break;
            case 11: /* 2048 */ scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len); break;
            default:
                break;
        }
    }
}
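A small host-only program can be used to sanity-check the launch geometry that get_batch_per_block and the dispatch functions above derive; the concrete numbers below are for key_seq_len = 1024 and assume the usual C10_WARP_SIZE of 32 (the program itself is illustrative and not part of the header):

#include <cstdio>

int main() {
  const int key_seq_len = 1024;
  int log2_elements = 0;                                                   // log2_ceil
  while ((1 << log2_elements) < key_seq_len) ++log2_elements;              // -> 10
  const int next_power_of_two = 1 << log2_elements;                        // -> 1024
  const int warp_size = next_power_of_two < 32 ? next_power_of_two : 32;   // -> 32
  const int warp_iterations = next_power_of_two / warp_size;               // -> 32 elements per thread
  const int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;         // -> 1
  const int threads_per_block = 128;
  const int warps_per_block = threads_per_block / warp_size;               // -> 4
  const int batches_per_block = warps_per_block * batches_per_warp;        // -> 4
  std::printf("warp_size=%d iterations=%d batches_per_block=%d\n",
              warp_size, warp_iterations, batches_per_block);
  return 0;
}

The resulting batches_per_block of 4 is also why dispatch_scaled_masked_softmax_forward asserts query_seq_len % batches_per_block == 0 before building its (query_seq_len / batches_per_block, attn_heads, batches) grid.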
csrc/megatron/scaled_masked_softmax_cuda.cu
new file mode 100644
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_masked_softmax.h"
#include "type_shim.h"
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {

int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){
    return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
}

torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    torch::Tensor const& mask,
    float scale_factor)
{
  // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
  const int batches = input.size(0);
  const int pad_batches = mask.size(0);
  const int attn_heads = input.size(1);
  const int query_seq_len = input.size(2);
  const int key_seq_len = input.size(3);
  TORCH_INTERNAL_ASSERT(key_seq_len <= 2048);
  TORCH_INTERNAL_ASSERT(query_seq_len > 1);
  TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
  TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
  TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);
  TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);

  // Output
  auto act_options = input.options().requires_grad(false);
  torch::Tensor softmax_results =
      torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);

  // Softmax Intermediate Result Ptr
  void* input_ptr = static_cast<void*>(input.data_ptr());
  void* mask_ptr = static_cast<void*>(mask.data_ptr());
  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());

  DISPATCH_HALF_AND_BFLOAT(
      input.scalar_type(),
      "dispatch_scaled_masked_softmax_forward",
      dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(softmax_results_ptr),
          reinterpret_cast<const scalar_t*>(input_ptr),
          reinterpret_cast<const uint8_t*>(mask_ptr),
          scale_factor,
          query_seq_len,
          key_seq_len,
          batches,
          attn_heads,
          pad_batches);
  );
  return softmax_results;
}

torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads_,
    torch::Tensor const& softmax_results_,
    float scale_factor) {

  auto output_grads = output_grads_.contiguous();
  auto softmax_results = softmax_results_.contiguous();

  //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
  const int batches = output_grads.size(0);
  const int attn_heads = output_grads.size(1);
  const int query_seq_len = output_grads.size(2);
  const int key_seq_len = output_grads.size(3);

  void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

  //Softmax Grad
  DISPATCH_HALF_AND_BFLOAT(
      output_grads_.scalar_type(),
      "dispatch_scaled_masked_softmax_backward",
      dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
          scale_factor,
          query_seq_len,
          key_seq_len,
          batches,
          attn_heads);
  );

  //backward pass is completely in-place
  return output_grads;
}
}
}
}
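bwd_cuda above passes output_grads as both the gradient input and the gradient output buffer, so the backward pass is in-place, and the math per softmax row is the usual Jacobian-vector product scaled by scale_factor. A plain CPU reference of one row, useful when eyeballing the kernel, could look like the following sketch (float stands in for the fp16/bf16 storage with fp32 accumulation, and the function name is invented here):

#include <cstddef>
#include <vector>

// Given y = softmax(scale * x) and upstream gradient dy, returns
// dx = scale * (y * dy - y * sum(y * dy)), matching scaled_masked_softmax_warp_backward.
std::vector<float> softmax_backward_row(const std::vector<float>& y,
                                        const std::vector<float>& dy,
                                        float scale) {
  float s = 0.f;
  for (std::size_t i = 0; i < y.size(); ++i) s += y[i] * dy[i];
  std::vector<float> dx(y.size());
  for (std::size_t i = 0; i < y.size(); ++i) dx[i] = scale * (y[i] * dy[i] - y[i] * s);
  return dx;
}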
csrc/megatron/scaled_upper_triang_masked_softmax.cpp
new file mode 100644
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {

torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    float scale_factor);

torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads,
    torch::Tensor const& softmax_results,
    float scale_factor);

torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
  AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
      (input.scalar_type() == at::ScalarType::BFloat16),
      "Only fp16 and bf16 are supported");

  return fwd_cuda(input, scale_factor);
}

torch::Tensor bwd(
    torch::Tensor const& output_grads,
    torch::Tensor const& softmax_results,
    float scale_factor) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
      (output_grads.scalar_type() == at::ScalarType::BFloat16),
      "Only fp16 and bf16 are supported");
  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
      (softmax_results.scalar_type() == at::ScalarType::BFloat16),
      "Only fp16 and bf16 are supported");

  return bwd_cuda(output_grads, softmax_results, scale_factor);
}

} // end namespace scaled_upper_triang_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward",
        &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
        "Self Multihead Attention scaled, time masked softmax -- Forward.");
  m.def("backward",
        &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
        "Self Multihead Attention scaled, time masked softmax -- Backward.");
}
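Unlike the explicitly masked variant, the bindings above take no mask tensor: the masking is implicit and causal, i.e. for query position q only key positions k <= q contribute, and everything above the diagonal is written out as zero by the kernels in the accompanying header. A minimal CPU sketch of that behaviour for one [seq_len, seq_len] score matrix, with float standing in for fp16/bf16 and an invented function name:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// In-place causal (upper-triangular masked) softmax over each row q of a
// row-major seq_len x seq_len score matrix, with input scaling applied first.
void causal_scaled_softmax(std::vector<float>& scores, int seq_len, float scale) {
  for (int q = 0; q < seq_len; ++q) {
    float* row = scores.data() + static_cast<std::size_t>(q) * seq_len;
    float max_val = -std::numeric_limits<float>::infinity();
    for (int k = 0; k <= q; ++k) max_val = std::max(max_val, row[k] * scale);
    float sum = 0.f;
    for (int k = 0; k <= q; ++k) { row[k] = std::exp(row[k] * scale - max_val); sum += row[k]; }
    for (int k = 0; k <= q; ++k) row[k] /= sum;
    for (int k = q + 1; k < seq_len; ++k) row[k] = 0.f;  // positions above the diagonal are zeroed
  }
}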
csrc/megatron/scaled_upper_triang_masked_softmax.h
new file mode 100644
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <assert.h>
#include <cuda_fp16.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <c10/macros/Macros.h>
namespace
{
template
<
typename
Datatype
,
int
ELEMENTS_PER_LDG
>
__device__
__inline__
void
copy_vector
(
Datatype
*
dst
,
const
Datatype
*
src
);
template
<
>
__device__
__inline__
void
copy_vector
<
c10
::
BFloat16
,
1
>
(
c10
::
BFloat16
*
dst
,
const
c10
::
BFloat16
*
src
)
{
*
dst
=
*
src
;
}
template
<
>
__device__
__inline__
void
copy_vector
<
c10
::
BFloat16
,
4
>
(
c10
::
BFloat16
*
dst
,
const
c10
::
BFloat16
*
src
)
{
*
((
float2
*
)
dst
)
=
*
((
float2
*
)
src
);
}
template
<
>
__device__
__inline__
void
copy_vector
<
c10
::
Half
,
1
>
(
c10
::
Half
*
dst
,
const
c10
::
Half
*
src
)
{
*
dst
=
*
src
;
}
template
<
>
__device__
__inline__
void
copy_vector
<
c10
::
Half
,
4
>
(
c10
::
Half
*
dst
,
const
c10
::
Half
*
src
)
{
*
((
float2
*
)
dst
)
=
*
((
float2
*
)
src
);
}
template
<
>
__device__
__inline__
void
copy_vector
<
uint8_t
,
1
>
(
uint8_t
*
dst
,
const
uint8_t
*
src
)
{
*
dst
=
*
src
;
}
template
<
>
__device__
__inline__
void
copy_vector
<
uint8_t
,
4
>
(
uint8_t
*
dst
,
const
uint8_t
*
src
)
{
*
((
half2
*
)
dst
)
=
*
((
half2
*
)
src
);
}
template
<
typename
Datatype
,
int
ELEMENTS_PER_LDG
>
__device__
__inline__
void
copy_zero_vector
(
Datatype
*
dst
);
template
<
>
__device__
__inline__
void
copy_zero_vector
<
c10
::
BFloat16
,
1
>
(
c10
::
BFloat16
*
dst
)
{
*
dst
=
0.0
;
}
template
<
>
__device__
__inline__
void
copy_zero_vector
<
c10
::
BFloat16
,
4
>
(
c10
::
BFloat16
*
dst
)
{
*
((
float2
*
)
dst
)
=
make_float2
(
0.0
f
,
0.0
f
);
}
template
<
>
__device__
__inline__
void
copy_zero_vector
<
c10
::
Half
,
1
>
(
c10
::
Half
*
dst
)
{
*
dst
=
0.0
;
}
template
<
>
__device__
__inline__
void
copy_zero_vector
<
c10
::
Half
,
4
>
(
c10
::
Half
*
dst
)
{
*
((
float2
*
)
dst
)
=
make_float2
(
0.0
f
,
0.0
f
);
}
int
log2_ceil
(
int
value
)
{
int
log2_value
=
0
;
while
((
1
<<
log2_value
)
<
value
)
++
log2_value
;
return
log2_value
;
}
template
<
typename
T
>
struct
Add
{
__device__
__forceinline__
T
operator
()(
T
a
,
T
b
)
const
{
return
a
+
b
;
}
};
template
<
typename
T
>
struct
Max
{
__device__
__forceinline__
T
operator
()(
T
a
,
T
b
)
const
{
return
a
<
b
?
b
:
a
;
}
};
template
<
typename
T
>
__device__
__forceinline__
T
WARP_SHFL_XOR_NATIVE
(
T
value
,
int
laneMask
,
int
width
=
warpSize
,
unsigned
int
mask
=
0xffffffff
)
{
#if CUDA_VERSION >= 9000
return
__shfl_xor_sync
(
mask
,
value
,
laneMask
,
width
);
#else
return
__shfl_xor
(
value
,
laneMask
,
width
);
#endif
}
template
<
typename
acc_t
,
int
WARP_BATCH
,
int
WARP_SIZE
,
template
<
typename
>
class
ReduceOp
>
__device__
__forceinline__
void
warp_reduce
(
acc_t
*
sum
)
{
ReduceOp
<
acc_t
>
r
;
#pragma unroll
for
(
int
offset
=
WARP_SIZE
/
2
;
offset
>
0
;
offset
/=
2
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
acc_t
b
=
WARP_SHFL_XOR_NATIVE
(
sum
[
i
],
offset
,
WARP_SIZE
);
sum
[
i
]
=
r
(
sum
[
i
],
b
);
}
}
}
/*
* Extended softmax (from native aten pytorch) with following additional features
* 1) input scaling
* 2) Implicit time (diagonal masking)
*/
template
<
typename
input_t
,
typename
output_t
,
typename
acc_t
,
int
log2_elements
>
__global__
void
scaled_upper_triang_masked_softmax_warp_forward
(
output_t
*
dst
,
const
input_t
*
src
,
const
acc_t
scale
,
int
micro_batch_size
,
int
stride
,
int
element_count
)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_forward_kernel.
constexpr
int
next_power_of_two
=
1
<<
log2_elements
;
constexpr
int
WARP_SIZE
=
(
next_power_of_two
<
C10_WARP_SIZE
)
?
next_power_of_two
:
C10_WARP_SIZE
;
constexpr
int
WARP_ITERATIONS
=
next_power_of_two
/
WARP_SIZE
;
constexpr
int
WARP_BATCH
=
(
next_power_of_two
<=
128
)
?
2
:
1
;
constexpr
int
ELEMENTS_PER_LDG_STG
=
(
WARP_ITERATIONS
<
4
)
?
1
:
4
;
int
first_batch
=
(
blockDim
.
y
*
blockIdx
.
y
+
threadIdx
.
y
)
*
gridDim
.
x
*
WARP_BATCH
+
blockIdx
.
x
;
int
local_seq
=
blockIdx
.
x
+
1
;
int
warp_iteration_limit
=
(
local_seq
+
ELEMENTS_PER_LDG_STG
*
WARP_SIZE
-
1
)
/
WARP_SIZE
;
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
int
local_batches
=
micro_batch_size
-
first_batch
;
if
(
local_batches
>
WARP_BATCH
)
local_batches
=
WARP_BATCH
;
// there might be multiple batches per warp. compute the index within the batch
int
local_idx
=
threadIdx
.
x
;
src
+=
first_batch
*
stride
+
ELEMENTS_PER_LDG_STG
*
local_idx
;
dst
+=
first_batch
*
stride
+
ELEMENTS_PER_LDG_STG
*
local_idx
;
// load data from global memory
acc_t
elements
[
WARP_BATCH
][
WARP_ITERATIONS
];
input_t
temp_data
[
ELEMENTS_PER_LDG_STG
];
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
int
batch_element_count
=
(
i
>=
local_batches
)
?
0
:
local_seq
;
#pragma unroll
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
it
+=
ELEMENTS_PER_LDG_STG
)
{
int
element_index
=
ELEMENTS_PER_LDG_STG
*
local_idx
+
it
*
WARP_SIZE
;
if
(
element_index
<
batch_element_count
)
{
copy_vector
<
input_t
,
ELEMENTS_PER_LDG_STG
>
(
temp_data
,
src
+
i
*
element_count
*
stride
+
it
*
WARP_SIZE
);
#pragma unroll
for
(
int
element
=
0
;
element
<
ELEMENTS_PER_LDG_STG
;
++
element
)
{
if
((
element_index
+
element
)
<
batch_element_count
)
{
elements
[
i
][
it
+
element
]
=
(
acc_t
)
temp_data
[
element
]
*
scale
;
}
else
{
elements
[
i
][
it
+
element
]
=
-
std
::
numeric_limits
<
acc_t
>::
infinity
();
}
}
}
else
{
#pragma unroll
for
(
int
element
=
0
;
element
<
ELEMENTS_PER_LDG_STG
;
++
element
)
{
elements
[
i
][
it
+
element
]
=
-
std
::
numeric_limits
<
acc_t
>::
infinity
();
}
}
}
}
// compute max_value
acc_t
max_value
[
WARP_BATCH
];
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
max_value
[
i
]
=
elements
[
i
][
0
];
#pragma unroll
for
(
int
it
=
1
;
it
<
WARP_ITERATIONS
;
++
it
)
{
max_value
[
i
]
=
(
max_value
[
i
]
>
elements
[
i
][
it
])
?
max_value
[
i
]
:
elements
[
i
][
it
];
}
}
warp_reduce
<
acc_t
,
WARP_BATCH
,
WARP_SIZE
,
Max
>
(
max_value
);
acc_t
sum
[
WARP_BATCH
]
{
0.0
f
};
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
#pragma unroll
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
++
it
)
{
if
(
it
<
warp_iteration_limit
)
{
elements
[
i
][
it
]
=
std
::
exp
((
elements
[
i
][
it
]
-
max_value
[
i
]));
sum
[
i
]
+=
elements
[
i
][
it
];
}
}
}
warp_reduce
<
acc_t
,
WARP_BATCH
,
WARP_SIZE
,
Add
>
(
sum
);
// store result
output_t
out
[
ELEMENTS_PER_LDG_STG
];
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
if
(
i
>=
local_batches
)
break
;
#pragma unroll
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
it
+=
ELEMENTS_PER_LDG_STG
)
{
int
element_index
=
ELEMENTS_PER_LDG_STG
*
local_idx
+
it
*
WARP_SIZE
;
if
(
element_index
<
local_seq
)
{
#pragma unroll
for
(
int
element
=
0
;
element
<
ELEMENTS_PER_LDG_STG
;
++
element
)
{
if
(
element_index
+
element
<
local_seq
)
{
out
[
element
]
=
elements
[
i
][
it
+
element
]
/
sum
[
i
];
}
else
{
out
[
element
]
=
0
;
}
}
copy_vector
<
output_t
,
ELEMENTS_PER_LDG_STG
>
(
dst
+
i
*
element_count
*
stride
+
it
*
WARP_SIZE
,
out
);
}
else
if
(
element_index
<
element_count
)
{
copy_zero_vector
<
output_t
,
ELEMENTS_PER_LDG_STG
>
(
dst
+
i
*
element_count
*
stride
+
it
*
WARP_SIZE
);
}
else
{
break
;
}
}
}
}
template
<
typename
input_t
,
typename
output_t
,
typename
acc_t
,
int
log2_elements
>
__global__
void
scaled_upper_triang_masked_softmax_warp_backward
(
output_t
*
gradInput
,
input_t
*
grad
,
const
input_t
*
output
,
acc_t
scale
,
int
micro_batch_size
,
int
stride
,
int
element_count
)
{
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_backward_kernel.
constexpr
int
next_power_of_two
=
1
<<
log2_elements
;
constexpr
int
WARP_SIZE
=
(
next_power_of_two
<
C10_WARP_SIZE
)
?
next_power_of_two
:
C10_WARP_SIZE
;
constexpr
int
WARP_ITERATIONS
=
next_power_of_two
/
WARP_SIZE
;
constexpr
int
WARP_BATCH
=
(
next_power_of_two
<=
128
)
?
2
:
1
;
constexpr
int
ELEMENTS_PER_LDG_STG
=
(
WARP_ITERATIONS
<
4
)
?
1
:
4
;
int
first_batch
=
(
blockDim
.
y
*
blockIdx
.
y
+
threadIdx
.
y
)
*
gridDim
.
x
*
WARP_BATCH
+
blockIdx
.
x
;
int
local_seq
=
blockIdx
.
x
+
1
;
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
int
local_batches
=
micro_batch_size
-
first_batch
;
if
(
local_batches
>
WARP_BATCH
)
local_batches
=
WARP_BATCH
;
// there might be multiple batches per warp. compute the index within the batch
int
local_idx
=
threadIdx
.
x
;
// the first element to process by the current thread
int
thread_offset
=
first_batch
*
stride
+
ELEMENTS_PER_LDG_STG
*
local_idx
;
grad
+=
thread_offset
;
output
+=
thread_offset
;
gradInput
+=
thread_offset
;
// load data from global memory
acc_t
grad_reg
[
WARP_BATCH
][
WARP_ITERATIONS
]
{
0.0
f
};
acc_t
output_reg
[
WARP_BATCH
][
WARP_ITERATIONS
]
{
0.0
f
};
input_t
temp_grad
[
ELEMENTS_PER_LDG_STG
];
input_t
temp_output
[
ELEMENTS_PER_LDG_STG
];
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
int
batch_element_count
=
(
i
>=
local_batches
)
?
0
:
local_seq
;
#pragma unroll
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
it
+=
ELEMENTS_PER_LDG_STG
)
{
int
element_index
=
ELEMENTS_PER_LDG_STG
*
local_idx
+
it
*
WARP_SIZE
;
if
(
element_index
<
batch_element_count
)
{
copy_vector
<
input_t
,
ELEMENTS_PER_LDG_STG
>
(
temp_grad
,
grad
+
i
*
element_count
*
stride
+
it
*
WARP_SIZE
);
copy_vector
<
input_t
,
ELEMENTS_PER_LDG_STG
>
(
temp_output
,
output
+
i
*
element_count
*
stride
+
it
*
WARP_SIZE
);
#pragma unroll
for
(
int
element
=
0
;
element
<
ELEMENTS_PER_LDG_STG
;
++
element
)
{
if
(
element_index
+
element
<
batch_element_count
)
{
output_reg
[
i
][
it
+
element
]
=
(
acc_t
)
temp_output
[
element
];
}
}
#pragma unroll
for
(
int
element
=
0
;
element
<
ELEMENTS_PER_LDG_STG
;
++
element
)
{
if
(
element_index
+
element
<
batch_element_count
)
{
grad_reg
[
i
][
it
+
element
]
=
(
acc_t
)
temp_grad
[
element
]
*
output_reg
[
i
][
it
+
element
];
}
}
}
}
}
acc_t
sum
[
WARP_BATCH
];
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
sum
[
i
]
=
grad_reg
[
i
][
0
];
#pragma unroll
for
(
int
it
=
1
;
it
<
WARP_ITERATIONS
;
++
it
)
{
sum
[
i
]
+=
grad_reg
[
i
][
it
];
}
}
warp_reduce
<
acc_t
,
WARP_BATCH
,
WARP_SIZE
,
Add
>
(
sum
);
// store result
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
if
(
i
>=
local_batches
)
break
;
#pragma unroll
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
it
+=
ELEMENTS_PER_LDG_STG
)
{
int
element_index
=
ELEMENTS_PER_LDG_STG
*
local_idx
+
it
*
WARP_SIZE
;
if
(
element_index
<
element_count
)
{
// compute gradients
output_t
out
[
ELEMENTS_PER_LDG_STG
];
#pragma unroll
for
(
int
element
=
0
;
element
<
ELEMENTS_PER_LDG_STG
;
++
element
)
{
out
[
element
]
=
(
output_t
)(
scale
*
(
grad_reg
[
i
][
it
+
element
]
-
output_reg
[
i
][
it
+
element
]
*
sum
[
i
]));
}
copy_vector
<
output_t
,
ELEMENTS_PER_LDG_STG
>
(
gradInput
+
i
*
element_count
*
stride
+
it
*
WARP_SIZE
,
out
);
}
}
}
}
}
// end of anonymous namespace
template <typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_upper_triang_masked_softmax_forward(
    output_t *dst,
    const input_t *src,
    const input_t scale,
    int softmax_elements,
    int softmax_elements_stride,
    int attn_batches)
{
    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
    if (softmax_elements == 0) {
        return;
    } else {
        int log2_elements = log2_ceil(softmax_elements);
        const int next_power_of_two = 1 << log2_elements;
        int seq_len = softmax_elements;
        int batch_count = attn_batches * seq_len;

        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

        // use 128 threads per block to maximize gpu utilization
        constexpr int threads_per_block = 128;

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);

        int blocks_per_seq = attn_batches / batches_per_block;
        dim3 blocks(seq_len, blocks_per_seq, 1);
        dim3 threads(warp_size, warps_per_block, 1);
        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
        switch (log2_elements) {
            case 0: // 1
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 1: // 2
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 2: // 4
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 3: // 8
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 4: // 16
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 5: // 32
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 6: // 64
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 7: // 128
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 8: // 256
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 9: // 512
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 10: // 1024
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 11: // 2048
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            default:
                break;
        }
    }
}
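The switch on log2_elements exists because the per-warp element count is a template parameter of the kernel, so the dispatcher has to pick one of the twelve instantiations and then derive a launch geometry that matches it. The host-only sketch below reproduces that arithmetic for one arbitrary example size; it is illustrative and not part of the commit (the local log2_ceil is assumed to behave like the helper used above, and 32 stands in for C10_WARP_SIZE on NVIDIA hardware).

#include <cstdio>

// Illustrative only: recompute the launch geometry chosen by
// dispatch_scaled_upper_triang_masked_softmax_forward for one seq_len.
static int log2_ceil_local(int value) {          // assumed to match the header's log2_ceil
  int log2_value = 0;
  while ((1 << log2_value) < value) ++log2_value;
  return log2_value;
}

int main() {
  const int warp_size_hw = 32;                   // C10_WARP_SIZE on NVIDIA GPUs
  const int threads_per_block = 128;             // fixed in the dispatcher above
  int seq_len = 384, attn_batches = 96;          // example sizes, chosen arbitrarily

  int log2_elements = log2_ceil_local(seq_len);  // 9, i.e. the 512-element instantiation
  int next_power_of_two = 1 << log2_elements;
  int warp_size = next_power_of_two < warp_size_hw ? next_power_of_two : warp_size_hw;
  int batches_per_warp = next_power_of_two <= 128 ? 2 : 1;
  int warps_per_block = threads_per_block / warp_size;
  int batches_per_block = warps_per_block * batches_per_warp;  // must divide attn_batches
  int blocks_per_seq = attn_batches / batches_per_block;

  printf("grid = (%d, %d, 1), block = (%d, %d, 1)\n",
         seq_len, blocks_per_seq, warp_size, warps_per_block);
  return 0;
}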
template <typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_upper_triang_masked_softmax_backward(
    output_t *grad_input,
    input_t *grad,
    const input_t *output,
    const acc_t scale,
    int softmax_elements,
    int softmax_elements_stride,
    int attn_batches)
{
    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
    if (softmax_elements == 0) {
        return;
    } else {
        int log2_elements = log2_ceil(softmax_elements);
        const int next_power_of_two = 1 << log2_elements;
        int seq_len = softmax_elements;
        int batch_count = attn_batches * seq_len;

        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

        // use 128 threads per block to maximize gpu utilization
        constexpr int threads_per_block = 128;

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);

        int blocks_per_seq = attn_batches / batches_per_block;
        dim3 blocks(seq_len, blocks_per_seq, 1);
        dim3 threads(warp_size, warps_per_block, 1);
        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
        switch (log2_elements) {
            case 0: // 1
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 1: // 2
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 2: // 4
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 3: // 8
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 4: // 16
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 5: // 32
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 6: // 64
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 7: // 128
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 8: // 256
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 9: // 512
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 10: // 1024
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 11: // 2048
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            default:
                break;
        }
    }
}
csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu
0 → 100644
View file @
f79993d9
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_upper_triang_masked_softmax.h"
#include "type_shim.h"
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {

torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    float scale_factor)
{
  // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
  const int attn_batches = input.size(0);
  const int seq_len = input.size(1);
  TORCH_INTERNAL_ASSERT(seq_len <= 2048);

  // Output
  auto act_options = input.options().requires_grad(false);
  torch::Tensor softmax_results =
      torch::empty({attn_batches, seq_len, seq_len}, act_options);

  // Softmax Intermediate Result Ptr
  void* input_ptr = static_cast<void*>(input.data_ptr());
  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());

  DISPATCH_HALF_AND_BFLOAT(
      input.scalar_type(),
      "dispatch_scaled_upper_triang_masked_softmax_forward",
      dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(softmax_results_ptr),
          reinterpret_cast<const scalar_t*>(input_ptr),
          scale_factor,
          seq_len,
          seq_len,
          attn_batches);
  );

  return softmax_results;
}

torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads_,
    torch::Tensor const& softmax_results_,
    float scale_factor)
{
  auto output_grads = output_grads_.contiguous();
  auto softmax_results = softmax_results_.contiguous();

  // output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
  const int attn_batches = output_grads.size(0);
  const int seq_len = output_grads.size(1);
  TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2));

  void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

  // Softmax Grad
  DISPATCH_HALF_AND_BFLOAT(
      output_grads_.scalar_type(),
      "dispatch_scaled_upper_triang_masked_softmax_backward",
      dispatch_scaled_upper_triang_masked_softmax_backward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
          scale_factor,
          seq_len,
          seq_len,
          attn_batches);
  );

  // backward pass is completely in-place
  return output_grads;
}
} // namespace scaled_upper_triang_masked_softmax
} // namespace fused_softmax
} // namespace multihead_attn
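fwd_cuda and bwd_cuda above are plain C++ entry points and only become visible to Python through a separate binding layer. The sketch below shows what a minimal pybind11/torch-extension registration for them could look like; the module layout and docstrings are assumptions for illustration, not the commit's actual binding file.

// Illustrative binding sketch (not the commit's binding file).
#include <torch/extension.h>

namespace multihead_attn { namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {
// Forward declarations of the entry points defined in the .cu file above.
torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor);
torch::Tensor bwd_cuda(torch::Tensor const& output_grads,
                       torch::Tensor const& softmax_results,
                       float scale_factor);
}}}  // namespaces

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  namespace ns = multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax;
  m.def("forward", &ns::fwd_cuda, "scaled upper-triangular masked softmax, forward");
  m.def("backward", &ns::bwd_cuda, "scaled upper-triangular masked softmax, backward");
}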
csrc/mlp.cpp
View file @
f79993d9
@@ -4,15 +4,7 @@
 #include <stdio.h>
-int SizeTToInt(size_t data) {
-  if (data > std::numeric_limits<int>::max()) {
-    throw std::runtime_error("Invalid cast.");
-  }
-  return static_cast<int>(data);
-}
-size_t get_mlp_reserved_space(int batch_size, int num_layers, const int* output_features);
+size_t get_mlp_reserved_space(int64_t batch_size, int num_layers, const int* output_features);
 template <typename T>
 size_t get_mlp_bp_workspace_in_bytes(int batch_size, int num_layers, const int* output_features);

@@ -29,7 +21,8 @@ int mlp_fp(
     T* Y,
     T* reserved_space,
     int use_bias,
-    int activation);
+    int activation,
+    void* lt_workspace);
 template <typename T>
 int mlp_bp(

@@ -68,9 +61,10 @@ std::vector<at::Tensor> mlp_forward(int use_bias, int activation, std::vector<at
   auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
   // create output/workspace tensor
+  // TODO(deyuf): just get buffer?
   auto out = at::empty({batch_size, output_features.back()}, inputs[0].type());
-  auto reserved_space = at::empty({SizeTToInt(reserved_size)}, inputs[0].type());
+  auto reserved_space = at::empty({reserved_size}, inputs[0].type());
+  // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
+  auto lt_workspace = at::empty({1 << 22}, inputs[0].type());
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputs[0].type(), "mlp_forward", [&] {
     std::vector<scalar_t*> w_ptr;

@@ -92,7 +86,8 @@ std::vector<at::Tensor> mlp_forward(int use_bias, int activation, std::vector<at
         out.data_ptr<scalar_t>(),
         reserved_space.data_ptr<scalar_t>(),
         use_bias,
-        activation);
+        activation,
+        (void*) (lt_workspace.data_ptr<scalar_t>()));
   });
   return {out, reserved_space};

@@ -114,7 +109,6 @@ std::vector<at::Tensor> mlp_backward(
   auto batch_size = inputs[0].size(0);
   auto input_features = inputs[0].size(1);
-  // TODO: not creating empty tensor for it?
   bool requires_grad = inputs[0].requires_grad();
   std::vector<int> output_features;

@@ -122,7 +116,6 @@ std::vector<at::Tensor> mlp_backward(
     output_features.push_back(inputs[i + 1].size(0));
   }
   // create outputs, length of inputs
-  // TODO: not create bias if not needed
   std::vector<at::Tensor> outputs;
   for (int i = 0; i < inputs.size(); i++) {
     outputs.push_back(at::empty(inputs[i].sizes(), inputs[i].type()));  // clone for testing now

@@ -142,7 +135,7 @@ std::vector<at::Tensor> mlp_backward(
         get_mlp_bp_workspace_in_bytes<scalar_t>(batch_size, num_layers, output_features.data());
     // auto work_space = at::empty({work_size*4}, at::kByte);
-    auto work_space = at::empty({SizeTToInt(work_size / sizeof(scalar_t))}, inputs[0].type());
+    auto work_space = at::empty({work_size / sizeof(scalar_t)}, inputs[0].type());
     auto result = mlp_bp<scalar_t>(
         inputs[0].data_ptr<scalar_t>(),

@@ -170,3 +163,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("forward", &mlp_forward, "MLP forward");
   m.def("backward", &mlp_backward, "MLP backward");
 }
csrc/mlp_cuda.cu
View file @
f79993d9
@@ -10,6 +10,10 @@
 #include <cublas_v2.h>
 #include <cuda_runtime.h>
+#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
+// includes cublaslt
+#include <cublasLt.h>
+#endif
 // constants for fused bias+relu kernel
 #define BIAS_RELU_FW_NTHREADS 128 // forward number of thread per block
 #define BIAS_RELU_BW_NTHREADS_X 32 // backward number of thread in feature dim
@@ -249,6 +253,268 @@ cublasStatus_t mlp_gemm(
       CUBLAS_GEMM_DEFAULT_TENSOR_OP);
 #endif
 }
+
+#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
+int mlp_gemm_lt(
+    cublasLtHandle_t ltHandle,
+    cublasOperation_t transa,
+    cublasOperation_t transb,
+    int m, int n, int k,
+    float* alpha, /* host pointer */
+    const at::Half* A, int lda,
+    const at::Half* B, int ldb,
+    float* beta, /* host pointer */
+    at::Half* C, int ldc,
+    void* workspace, size_t workspaceSize,
+    cudaStream_t stream,
+    bool use_bias, bool use_relu,
+    const void* bias) {
+  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
+
+  cublasLtMatmulDescOpaque_t operationDesc = {};
+  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
+  cublasLtMatmulPreferenceOpaque_t preference = {};
+
+  int returnedResults = 0;
+  cublasLtMatmulHeuristicResult_t heuristicResult = {};
+  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
+
+  // Create operation descriptor; see cublasLtMatmulDescAttributes_t
+  // for details about defaults; here we just set the transforms for
+  // A and B.
+  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+
+  if (use_bias) {
+    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
+    if (status != CUBLAS_STATUS_SUCCESS) {
+      goto CLEANUP;
+    }
+    if (use_relu) {
+      epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;
+    } else {
+      epilogue = CUBLASLT_EPILOGUE_BIAS;
+    }
+  } else {
+    if (use_relu) {
+      epilogue = CUBLASLT_EPILOGUE_RELU;
+    }
+  }
+
+  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
+  if (status != CUBLAS_STATUS_SUCCESS) {
+    goto CLEANUP;
+  }
+
+  // Create matrix descriptors. Not setting any extra attributes.
+  status = cublasLtMatrixLayoutInit(&Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatrixLayoutInit(&Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+
+  // Create preference handle; In general, extra attributes can be
+  // used here to disable tensor ops or to make sure algo selected
+  // will work with badly aligned A, B, C. However, for simplicity
+  // here we assume A,B,C are always well aligned (e.g., directly
+  // come from cudaMalloc)
+  status = cublasLtMatmulPreferenceInit(&preference);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatmulPreferenceSetAttribute(&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+
+  // We just need the best available heuristic to try and run matmul.
+  // There is no guarantee that this will work. For example, if A is
+  // badly aligned, you can request more (e.g. 32) algos and try to
+  // run them one by one until something works.
+  status = cublasLtMatmulAlgoGetHeuristic(ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+
+  if (returnedResults == 0) {
+    status = CUBLAS_STATUS_NOT_SUPPORTED;
+    goto CLEANUP;
+  }
+
+  status = cublasLtMatmul(ltHandle,
+                          &operationDesc,
+                          alpha,
+                          A, &Adesc,
+                          B, &Bdesc,
+                          beta,
+                          C, &Cdesc,
+                          C, &Cdesc,
+                          &heuristicResult.algo,
+                          workspace, workspaceSize,
+                          stream);
+
+CLEANUP:
+  // Descriptors are no longer needed as all GPU work was already
+  // enqueued.
+  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
+}
+
+int mlp_gemm_lt(
+    cublasLtHandle_t ltHandle,
+    cublasOperation_t transa,
+    cublasOperation_t transb,
+    int m, int n, int k,
+    float* alpha, /* host pointer */
+    const double* A, int lda,
+    const double* B, int ldb,
+    float* beta, /* host pointer */
+    double* C, int ldc,
+    void* workspace, size_t workspaceSize,
+    cudaStream_t stream,
+    bool use_bias, bool use_relu,
+    const void* bias) {
+  return 1;
+}
+
+int mlp_gemm_lt(
+    cublasLtHandle_t ltHandle,
+    cublasOperation_t transa,
+    cublasOperation_t transb,
+    int m, int n, int k,
+    float* alpha, /* host pointer */
+    const float* A, int lda,
+    const float* B, int ldb,
+    float* beta, /* host pointer */
+    float* C, int ldc,
+    void* workspace, size_t workspaceSize,
+    cudaStream_t stream,
+    bool use_bias, bool use_relu,
+    const void* bias) {
+  cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
+
+  cublasLtMatmulDescOpaque_t operationDesc = {};
+  cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
+  cublasLtMatmulPreferenceOpaque_t preference = {};
+
+  int returnedResults = 0;
+  cublasLtMatmulHeuristicResult_t heuristicResult = {};
+  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
+
+  // Create operation descriptor; see cublasLtMatmulDescAttributes_t
+  // for details about defaults; here we just set the transforms for
+  // A and B.
+  status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+
+  if (use_bias) {
+    status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
+    if (status != CUBLAS_STATUS_SUCCESS) {
+      goto CLEANUP;
+    }
+    if (use_relu) {
+      epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;
+    } else {
+      epilogue = CUBLASLT_EPILOGUE_BIAS;
+    }
+  } else {
+    if (use_relu) {
+      epilogue = CUBLASLT_EPILOGUE_RELU;
+    }
+  }
+
+  status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
+  if (status != CUBLAS_STATUS_SUCCESS) {
+    goto CLEANUP;
+  }
+
+  // Create matrix descriptors. Not setting any extra attributes.
+  status = cublasLtMatrixLayoutInit(&Adesc, CUDA_R_32F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatrixLayoutInit(&Bdesc, CUDA_R_32F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_32F, m, n, ldc);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+
+  // Create preference handle; In general, extra attributes can be
+  // used here to disable tensor ops or to make sure algo selected
+  // will work with badly aligned A, B, C. However, for simplicity
+  // here we assume A,B,C are always well aligned (e.g., directly
+  // come from cudaMalloc)
+  status = cublasLtMatmulPreferenceInit(&preference);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+  status = cublasLtMatmulPreferenceSetAttribute(&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+
+  // We just need the best available heuristic to try and run matmul.
+  // There is no guarantee that this will work. For example, if A is
+  // badly aligned, you can request more (e.g. 32) algos and try to
+  // run them one by one until something works.
+  status = cublasLtMatmulAlgoGetHeuristic(ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
+  if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
+
+  if (returnedResults == 0) {
+    status = CUBLAS_STATUS_NOT_SUPPORTED;
+    goto CLEANUP;
+  }
+
+  status = cublasLtMatmul(ltHandle,
+                          &operationDesc,
+                          alpha,
+                          A, &Adesc,
+                          B, &Bdesc,
+                          beta,
+                          C, &Cdesc,
+                          C, &Cdesc,
+                          &heuristicResult.algo,
+                          workspace, workspaceSize,
+                          stream);
+
+CLEANUP:
+  // Descriptors are no longer needed as all GPU work was already
+  // enqueued.
+  return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
+}
+#endif
+
 // Bias ADD. Assume input X is [features x batch size], column major.
 // Bias is one 'features' long vector, with implicit broadcast.
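The comment before cublasLtMatmulAlgoGetHeuristic in mlp_gemm_lt notes that a single heuristic result is not guaranteed to run, and that one can request more algorithms (e.g. 32) and try them in turn. The fragment below sketches that retry loop in isolation, assuming descriptors, preference and device buffers were prepared the same way as in mlp_gemm_lt; it is illustrative and not part of the commit.

// Illustrative retry loop for cublasLt heuristics (not part of the commit).
#include <cublasLt.h>
#include <cuda_runtime.h>

cublasStatus_t matmul_with_retry(
    cublasLtHandle_t ltHandle, cublasLtMatmulDesc_t opDesc,
    const float* alpha, const void* A, cublasLtMatrixLayout_t Adesc,
    const void* B, cublasLtMatrixLayout_t Bdesc,
    const float* beta, void* C, cublasLtMatrixLayout_t Cdesc,
    cublasLtMatmulPreference_t preference,
    void* workspace, size_t workspaceSize, cudaStream_t stream) {
  // Ask for several candidate algorithms instead of just one.
  constexpr int kRequested = 8;
  cublasLtMatmulHeuristicResult_t results[kRequested] = {};
  int returned = 0;
  cublasStatus_t status = cublasLtMatmulAlgoGetHeuristic(
      ltHandle, opDesc, Adesc, Bdesc, Cdesc, Cdesc,
      preference, kRequested, results, &returned);
  if (status != CUBLAS_STATUS_SUCCESS || returned == 0)
    return CUBLAS_STATUS_NOT_SUPPORTED;

  // Try each candidate until one launches successfully.
  for (int i = 0; i < returned; ++i) {
    status = cublasLtMatmul(ltHandle, opDesc, alpha, A, Adesc, B, Bdesc,
                            beta, C, Cdesc, C, Cdesc, &results[i].algo,
                            workspace, workspaceSize, stream);
    if (status == CUBLAS_STATUS_SUCCESS) return status;
  }
  return status;  // every candidate failed; the caller can fall back to plain cublas
}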
@@ -538,7 +804,7 @@ void get_biasAddRelu_bprop_grid_size(
   // Get number of SMs for efficient reduction.
   int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
   // can switch to occupancy calculation. use 4 below now for sm_70
-  int max_blocks_y = num_SMs * 4 / (*grid_x);
+  int max_blocks_y = (num_SMs * 4 + (*grid_x) - 1) / (*grid_x);
   // block_y should be from minimal work per thread
   int nRedSplits = (batch_size + block_y - 1) / block_y;
   // increase number of elem per thread reduction to not launch more than enough

@@ -583,7 +849,7 @@ __global__ void biasAdd_bprop(
   int nidx = 0;
   // Handle non-multiple of UNROLL_FACTOR residue
   for (; nidx < nSpan % UNROLL_FACTOR; nidx++) {
-    int row, col, flat_idx;
+    int64_t row, col, flat_idx;
     row = f;
     col = nStart + nidx;
     flat_idx = col * features + row;

@@ -592,7 +858,7 @@ __global__ void biasAdd_bprop(
   // Handle meat of work
   for (; (nidx + UNROLL_FACTOR - 1) < nSpan; nidx += UNROLL_FACTOR) {
-    int row, col, flat_idx;
+    int64_t row, col, flat_idx;
     row = f;
     col = nStart + nidx;
     flat_idx = col * features + row;
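The two hunks above widen row, col and flat_idx in biasAdd_bprop from int to int64_t. The reason is that flat_idx = col * features + row multiplies two 32-bit quantities and can exceed INT_MAX once batch size times feature count grows past roughly 2.1 billion elements. The short host-side check below illustrates the difference with arbitrary example sizes (not taken from the commit).

#include <cstdint>
#include <cstdio>

int main() {
  // Example sizes only: large enough that col * features exceeds INT_MAX.
  int features = 4096;
  int col = 600000;   // a column index late in a very large batch
  int row = 100;

  int64_t exact = static_cast<int64_t>(col) * features + row;   // what the int64_t code computes
  int32_t truncated = static_cast<int32_t>(exact);               // what a 32-bit index would hold

  printf("64-bit flat index: %lld\n", static_cast<long long>(exact));
  printf("32-bit flat index: %d (wrapped)\n", truncated);
  return 0;
}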
@@ -865,7 +1131,6 @@ __global__ void biasAddRelu_bprop_aligned(
   }
   // block result is in db_local now for all threadIdx.y == 0
-  // TODO: maybe not useful early exit here
   if (gridDim.y == 1) {
 #pragma unroll
     for (int ii = 0; ii < ILP; ii++){

@@ -932,7 +1197,7 @@ void get_y_offsets(
 }
 // Returns the reserved space (in elements) needed for the MLP
-size_t get_mlp_reserved_space(int batch_size, int num_layers, const int* output_features) {
+size_t get_mlp_reserved_space(int64_t batch_size, int num_layers, const int* output_features) {
   size_t res_space = 0;
   // Need to store output of every intermediate MLP - size equal to output_features[i] * batch_size
   // for all 'i' in [0, num_layers-1)

@@ -943,7 +1208,7 @@ size_t get_mlp_reserved_space(int batch_size, int num_layers, const int* output_
 }
 // Returns the size of all fprop activations combined
-size_t get_all_activations_size(int batch_size, int num_layers, const int* output_features) {
+size_t get_all_activations_size(int64_t batch_size, int num_layers, const int* output_features) {
   size_t acts_size = 0;
   for (int l = 0; l < num_layers; l++) {
     acts_size += output_features[l] * batch_size;

@@ -1064,7 +1329,8 @@ int mlp_fp(
     T* Y,
     T* reserved_space,
     int use_bias,
-    int activation) {
+    int activation,
+    void* lt_workspace) {
   T *weight, *input, *output, *bias;
   T *reserved_space_x, *reserved_space_y;
   reserved_space_x = NULL;
@@ -1089,9 +1355,40 @@ int mlp_fp(
     float one = 1.f;
     float zero = 0.f;
-    cublasStatus_t cublas_status;
-    // Call GEMM: fprop is Y = W'X
-    cublas_status = mlp_gemm(
+    // try with cublaslt first for supported case with valid handle
+    int cublaslt_status = 1;
+#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
+    if (activation < 1) {
+      cublaslt_status = mlp_gemm_lt(
+          //ltHandle,
+          (cublasLtHandle_t)handle,
+          CUBLAS_OP_T,
+          CUBLAS_OP_N,
+          ofeat,
+          batch_size,
+          ifeat,
+          &one,
+          weight,
+          ifeat,
+          input,
+          ifeat,
+          &zero,
+          output,
+          ofeat,
+          lt_workspace,
+          1 << 22,
+          stream,
+          use_bias == 1,
+          activation == 1,
+          bias);
+    }
+#endif
+    // if cublaslt failed or not executed, fallback to cublas
+    if (cublaslt_status != 0) {
+      cublasStatus_t cublas_status;
+      // Call GEMM: fprop is Y = W'X
+      cublas_status = mlp_gemm(
        handle,
        CUBLAS_OP_T,
        CUBLAS_OP_N,

@@ -1107,39 +1404,39 @@ int mlp_fp(
        output,
        ofeat);

      if (cublas_status != CUBLAS_STATUS_SUCCESS) {
        printf("GEMM fprop failed with %d\n", cublas_status);
        return 1;
      }

      const uint &input_size = ofeat;
      int num_blocks = 0;
      int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
      // Call biasReLU
      if (use_bias == 1) {
        if (activation == 0) { // no activation
          cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, biasAdd_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
          biasAdd_fprop<<<num_SMs * num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, bias, batch_size, input_size);
        } else if (activation == 1) { // relu
          cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, biasAddRelu_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
          biasAddRelu_fprop<<<num_SMs * num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, bias, batch_size, input_size);
        } else if (activation == 2) { // sigmoid
          cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, biasAdd_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
          biasAdd_fprop<<<num_SMs * num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, bias, batch_size, input_size);
          cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
          Sigmoid_fprop<<<num_SMs * num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, batch_size, input_size);
        }
      } else {
        // don't need to do anything in case of no activation and no bias
        if (activation == 1) { // relu
          cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Relu_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
          Relu_fprop<<<num_SMs * num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, batch_size, input_size);
        } else if (activation == 2) { // sigmoid
          cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
          Sigmoid_fprop<<<num_SMs * num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, batch_size, input_size);
        }
      }
    }

    // Set current output as next layer input
    reserved_space_x = reserved_space_y;
    // Set next layer output
@@ -1366,7 +1663,8 @@ template int mlp_fp<float>(
     float* Y,
     float* reserved_space,
     int use_bias,
-    int activation);
+    int activation,
+    void* lt_workspace);
 template int mlp_bp<float>(
     float* X,

@@ -1397,7 +1695,8 @@ template int mlp_fp<at::Half>(
     at::Half* Y,
     at::Half* reserved_space,
     int use_bias,
-    int activation);
+    int activation,
+    void* lt_workspace);
 template int mlp_bp<at::Half>(
     at::Half* X,

@@ -1428,7 +1727,8 @@ template int mlp_fp<double>(
     double* Y,
     double* reserved_space,
     int use_bias,
-    int activation);
+    int activation,
+    void* lt_workspace);
 template int mlp_bp<double>(
     double* X,

@@ -1460,3 +1760,4 @@ template size_t get_mlp_bp_workspace_in_bytes<double>(
     int batch_size,
     int num_layers,
     const int* output_features);
csrc/multi_tensor_l2norm_kernel.cu
View file @
f79993d9
@@ -435,6 +435,11 @@ void multi_tensor_norm_out_cuda(
   // I could get rid of these by hacking the functor + multi tensor harness with persistence
   // logic, but keeping it simple for now
   auto ret = at::empty({1}, output.options());
+  // Adding the following device guard since it happens sometimes that the
+  // tensors are on one device and the cuda stream is on another device which
+  // results in ILLEGAL MEM ACCESS error.
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
   auto stream = at::cuda::getCurrentCUDAStream();
   cleanup_v2<<<ntensors, 512, 0, stream>>>(
       output.DATA_PTR<float>(),
csrc/multi_tensor_l2norm_scale_kernel.cu
0 → 100644
View file @
f79993d9
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include "type_shim.h"
#include "multi_tensor_apply.cuh"
#define BLOCK_SIZE 512
#define ILP 4
template <typename T>
__device__ __forceinline__ bool is_aligned(T* p) {
  return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
}

template <typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset) {
  typedef typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type LT;
  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}
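is_aligned and load_store implement the vectorized path used below: when a pointer is aligned to ILP elements, four values are moved with one aliased aligned_storage load/store instead of four scalar accesses. The host-only snippet below exercises the same trick so the mechanism can be seen in isolation; it is illustrative and not part of the commit.

#include <cstdio>
#include <type_traits>

#define ILP 4

// Same aliasing trick as the device-side load_store above, run on the host.
template <typename T>
void load_store_host(T* dst, T* src, int dst_offset, int src_offset) {
  typedef typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type LT;
  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}

int main() {
  alignas(ILP * sizeof(float)) float src[ILP] = {1.f, 2.f, 3.f, 4.f};
  alignas(ILP * sizeof(float)) float dst[ILP] = {};
  load_store_host(dst, src, 0, 0);  // one 16-byte move instead of four 4-byte ones
  printf("%g %g %g %g\n", dst[0], dst[1], dst[2], dst[3]);
  return 0;
}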
template <typename in_t, typename out_t>
struct L2NormScaleFunctor {
  __device__ __forceinline__ void operator()(
      int chunk_size,
      volatile int* noop_gmem,
      TensorListMetadata<2>& tl,
      float* output,
      float* output_per_tensor,
      float scale,
      bool per_tensor,
      int max_chunks_per_tensor) {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    in_t* in = (in_t*)tl.addresses[0][tensor_loc];
    in += chunk_idx * chunk_size;

    out_t* out = (out_t*)tl.addresses[1][tensor_loc];
    out += chunk_idx * chunk_size;

    n -= chunk_idx * chunk_size;

    __shared__ float s_vals[512];

    float vals[ILP];  // = {0}; // this probably works too but I want to be sure...
    in_t r_in[ILP];
    for (int i = 0; i < ILP; i++) {
      vals[i] = 0.f;
      r_in[i] = 0;
    }
    // bool finite = true;
    out_t r_out[ILP];

    // to make things simple, we put aligned case in a different code path
    if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(in) && is_aligned(out)) {
      for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) {
        // load
        load_store(r_in, in, 0, i_start);
#pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
          float next = static_cast<float>(r_in[ii]);
          r_out[ii] = next * scale;
          vals[ii] += next * next;
          // finite = finite && isfinite(r_in[ii]);
        }
        load_store(out, r_out, i_start, 0);
      }
    } else {
      for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
#pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
          r_in[ii] = 0;
          int i = i_start + threadIdx.x + ii * blockDim.x;
          if (i < n && i < chunk_size) {
            r_in[ii] = in[i];
            float next = static_cast<float>(in[i]);
            vals[ii] += next * next;
          }
        }
#pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
          r_out[ii] = static_cast<float>(r_in[ii]) * scale;
          // finite = finite && isfinite(r_in[ii]);
        }
#pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
          int i = i_start + threadIdx.x + ii * blockDim.x;
          if (i < n && i < chunk_size)
            out[i] = r_out[ii];
        }
      }
    }

    float val = 0.f;
    for (int i = 0; i < ILP; i++)
      val += vals[i];

    float final = reduce_block_into_lanes(s_vals, val);

    if (threadIdx.x == 0) {
      if (!isfinite(final))
        *noop_gmem = 1;
      // Blindly fire off a write. These will race but that's ok.
      output[blockIdx.x] += final;
      if (per_tensor)
        output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) * max_chunks_per_tensor + chunk_idx] = final;
    }
  }
};
// Probably better to template, but since we are not likely to support other norm
template <typename x_t>
struct MaxNormFunctor {
  __device__ __forceinline__ void operator()(
      int chunk_size,
      volatile int* noop_gmem,
      TensorListMetadata<1>& tl,
      float* output,
      float* output_per_tensor,
      bool per_tensor,
      int max_chunks_per_tensor) {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    x_t* x = (x_t*)tl.addresses[0][tensor_loc];
    x += chunk_idx * chunk_size;

    n -= chunk_idx * chunk_size;

    __shared__ float s_vals[512];

    float vals[ILP];  // = {0}; // this probably works too but I want to be sure...
    x_t r_x[ILP];
    for (int i = 0; i < ILP; i++) {
      vals[i] = 0.f;
      r_x[i] = 0;
    }

    // to make things simple, we put aligned case in a different code path
    if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x)) {
      for (int i_start = threadIdx.x; i_start * ILP < n && i_start * ILP < chunk_size; i_start += blockDim.x) {
        // load
        load_store(r_x, x, 0, i_start);
#pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
          float next = static_cast<float>(r_x[ii]);
          vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next));
        }
      }
    } else {
      for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
#pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
          int i = i_start + threadIdx.x + ii * blockDim.x;
          if (i < n && i < chunk_size) {
            float next = static_cast<float>(x[i]);
            vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next));
          }
        }
      }
    }

    float val = 0.f;
    for (int i = 0; i < ILP; i++)
      val = fmaxf(fabsf(val), fabsf(vals[i]));

    float final = reduce_block_into_lanes_max_op(s_vals, val);

    if (threadIdx.x == 0) {
      if (!isfinite(final))
        *noop_gmem = 1;
      // Blindly fire off a write. These will race but that's ok.
      output[blockIdx.x] = fmaxf(fabsf(output[blockIdx.x]), fabsf(final));
      if (per_tensor)
        output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) * max_chunks_per_tensor + chunk_idx] = final;
    }
  }
};
__global__ void cleanup_v3(
    float* output,
    float* output_per_tensor,
    float* ret,
    float* ret_per_tensor,
    bool per_tensor,
    int max_chunks_per_tensor) {
  __shared__ float vals[512];

  if (blockIdx.x == 0) {
    float val = 0;
    if (threadIdx.x < 320)
      val = output[threadIdx.x];

    float final = reduce_block_into_lanes(vals, val);

    if (threadIdx.x == 0)
      *ret = sqrt(final);
  }

  if (per_tensor) {
    float* output_this_tensor = output_per_tensor + blockIdx.x * max_chunks_per_tensor;

    float val = 0;
    for (int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x)
      val += output_this_tensor[i];

    float final = reduce_block_into_lanes(vals, val);

    if (threadIdx.x == 0)
      ret_per_tensor[blockIdx.x] = sqrt(final);
  }
}
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_scale_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    float scale,
    at::optional<bool> per_tensor_python) {
  bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;

  auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
  auto output = at::zeros({320}, float_options);

  at::Tensor output_per_tensor;
  at::Tensor ret_per_tensor;

  int ntensors = tensor_lists[0].size();
  int max_chunks_per_tensor = -1;

  if (per_tensor) {
    for (int t = 0; t < ntensors; t++) {
      int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
      if (max_chunks_this_tensor > max_chunks_per_tensor)
        max_chunks_per_tensor = max_chunks_this_tensor;
    }
    output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor}, float_options);
    ret_per_tensor = at::empty({ntensors}, float_options);
  } else {
    ret_per_tensor = at::empty({0}, float_options);
  }

  DISPATCH_FLOAT_AND_HALF(
      tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_scale_cuda",
      DISPATCH_FLOAT_AND_HALF(
          tensor_lists[1][0].scalar_type(), 1, "multi_tensor_l2norm_scale_cuda",
          multi_tensor_apply<2>(
              BLOCK_SIZE,
              chunk_size,
              noop_flag,
              tensor_lists,
              L2NormScaleFunctor<scalar_t_0, scalar_t_1>(),
              output.DATA_PTR<float>(),
              per_tensor ? output_per_tensor.DATA_PTR<float>() : nullptr,
              scale,
              per_tensor,
              max_chunks_per_tensor);))

  AT_CUDA_CHECK(cudaGetLastError());
  // AT_CUDA_CHECK(cudaDeviceSynchronize());

  // This involves one more small kernel launch, but will be negligible end to end.
  // I could get rid of these by hacking the functor + multi tensor harness with persistence
  // logic, but keeping it simple for now
  auto ret = at::empty({1}, output.options());
  const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
  auto stream = at::cuda::getCurrentCUDAStream();
  cleanup_v3<<<per_tensor ? ntensors : 1, 512, 0, stream>>>(
      output.DATA_PTR<float>(),
      per_tensor ? output_per_tensor.DATA_PTR<float>() : nullptr,
      ret.DATA_PTR<float>(),
      per_tensor ? ret_per_tensor.DATA_PTR<float>() : nullptr,
      per_tensor,
      max_chunks_per_tensor);

  return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
}
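Taken together, L2NormScaleFunctor writes out[i] = in[i] * scale while accumulating the sum of squares of the unscaled inputs, and cleanup_v3 reduces those partial sums into a global norm plus optional per-tensor norms. A host-side reference of that contract, of the kind one might use in a unit test, is sketched below; it is illustrative and not part of the commit.

#include <cmath>
#include <cstddef>
#include <vector>

// Host reference for the contract implemented by L2NormScaleFunctor + cleanup_v3:
// scale every element into 'out' and return the global L2 norm of the unscaled
// inputs, plus one norm per tensor.
float l2norm_scale_reference(const std::vector<std::vector<float>>& inputs,
                             std::vector<std::vector<float>>& outputs,
                             std::vector<float>& per_tensor_norms,
                             float scale) {
  double total = 0.0;
  outputs.resize(inputs.size());
  per_tensor_norms.resize(inputs.size());
  for (size_t t = 0; t < inputs.size(); ++t) {
    double tensor_sum = 0.0;
    outputs[t].resize(inputs[t].size());
    for (size_t i = 0; i < inputs[t].size(); ++i) {
      float v = inputs[t][i];
      outputs[t][i] = v * scale;    // same write the kernel performs
      tensor_sum += double(v) * v;  // same accumulation, taken before scaling
    }
    per_tensor_norms[t] = static_cast<float>(std::sqrt(tensor_sum));
    total += tensor_sum;
  }
  return static_cast<float>(std::sqrt(total));
}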