OpenDAS / Megatron-LM, commit ee7b19e7
Authored Apr 02, 2021 by Mostofa Patwary
Merge branch 'main' into main_dedup
Parents: d413bd5f and f2d64c00

Showing 20 of the 49 changed files, with 1249 additions and 715 deletions.
Changed files shown in this view:

  examples/create_embeddings.sh                                        +32    -0
  examples/evaluate_ict_zeroshot_nq.sh                                 +36    -0
  examples/pretrain_ict.sh                                             +44    -0
  megatron/arguments.py                                                +68   -39
  megatron/checkpointing.py                                            +63   -44
  megatron/data/biencoder_dataset_utils.py                            +208    -0
  megatron/data/data_samplers.py                                       +14    -4
  megatron/data/ict_dataset.py                                         +20    -4
  megatron/data/orqa_wiki_dataset.py                                  +205    -0
  megatron/data/realm_index.py                                         +88   -80
  megatron/fused_kernels/__init__.py                                   +70   -87
  megatron/fused_kernels/layer_norm_cuda.cpp                           +32   -91
  megatron/fused_kernels/layer_norm_cuda_kernel.cu                     +33   -33
  megatron/fused_kernels/scaled_masked_softmax.cpp                      +9    -6
  megatron/fused_kernels/scaled_masked_softmax.h                       +86   -42
  megatron/fused_kernels/scaled_masked_softmax_cuda.cu                 +28   -20
  megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp        +10    -7
  megatron/fused_kernels/scaled_upper_triang_masked_softmax.h         +109   -37
  megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu    +25   -16
  megatron/fused_kernels/type_shim.h                                   +69  -205
examples/create_embeddings.sh (new file, mode 100644)

#!/bin/bash

# Compute embeddings for each entry of a given dataset (e.g. Wikipedia)

RANK=0
WORLD_SIZE=1

# Wikipedia data can be downloaded from the following link:
# https://github.com/facebookresearch/DPR/blob/master/data/download_data.py
EVIDENCE_DATA_DIR=<Specify path of Wikipedia dataset>
EMBEDDING_PATH=<Specify path to store embeddings>
CHECKPOINT_PATH=<Specify path of pretrained ICT model>

python tools/create_doc_index.py \
    --num-layers 12 \
    --hidden-size 768 \
    --num-attention-heads 12 \
    --tensor-model-parallel-size 1 \
    --micro-batch-size 128 \
    --checkpoint-activations \
    --seq-length 512 \
    --retriever-seq-length 256 \
    --max-position-embeddings 512 \
    --load ${CHECKPOINT_PATH} \
    --evidence-data-path ${EVIDENCE_DATA_DIR} \
    --embedding-path ${EMBEDDING_PATH} \
    --indexer-log-interval 1000 \
    --indexer-batch-size 128 \
    --vocab-file bert-vocab.txt \
    --num-workers 2 \
    --fp16
examples/evaluate_ict_zeroshot_nq.sh (new file, mode 100644)

#!/bin/bash

# Evaluate natural question test data given Wikipedia embeddings and pretrained
# ICT model

# Datasets can be downloaded from the following link:
# https://github.com/facebookresearch/DPR/blob/master/data/download_data.py
EVIDENCE_DATA_DIR=<Specify path of Wikipedia dataset>
EMBEDDING_PATH=<Specify path of the embeddings>
CHECKPOINT_PATH=<Specify path of pretrained ICT model>
QA_FILE=<Path of the natural question test dataset>

python tasks/main.py \
    --task ICT-ZEROSHOT-NQ \
    --tokenizer-type BertWordPieceLowerCase \
    --num-layers 12 \
    --hidden-size 768 \
    --num-attention-heads 12 \
    --tensor-model-parallel-size 1 \
    --micro-batch-size 128 \
    --checkpoint-activations \
    --seq-length 512 \
    --max-position-embeddings 512 \
    --load ${CHECKPOINT_PATH} \
    --evidence-data-path ${EVIDENCE_DATA_DIR} \
    --embedding-path ${EMBEDDING_PATH} \
    --retriever-seq-length 256 \
    --vocab-file bert-vocab.txt \
    --qa-data-test ${QA_FILE} \
    --num-workers 2 \
    --faiss-use-gpu \
    --retriever-report-topk-accuracies 1 5 20 100 \
    --fp16
examples/pretrain_ict.sh (new file, mode 100755)

#! /bin/bash

# Runs the "217M" parameter biencoder model for ICT retriever

RANK=0
WORLD_SIZE=1

PRETRAINED_BERT_PATH=<Specify path of pretrained BERT model>
TEXT_DATA_PATH=<Specify path and file prefix of the text data>
TITLE_DATA_PATH=<Specify path and file prefix of the titles>
CHECKPOINT_PATH=<Specify path>

python pretrain_ict.py \
    --num-layers 12 \
    --hidden-size 768 \
    --num-attention-heads 12 \
    --tensor-model-parallel-size 1 \
    --micro-batch-size 32 \
    --seq-length 256 \
    --max-position-embeddings 512 \
    --train-iters 100000 \
    --vocab-file bert-vocab.txt \
    --tokenizer-type BertWordPieceLowerCase \
    --DDP-impl torch \
    --bert-load ${PRETRAINED_BERT_PATH} \
    --log-interval 100 \
    --eval-interval 1000 \
    --eval-iters 10 \
    --retriever-report-topk-accuracies 1 5 10 20 100 \
    --retriever-score-scaling \
    --load $CHECKPOINT_PATH \
    --save $CHECKPOINT_PATH \
    --data-path ${TEXT_DATA_PATH} \
    --titles-data-path ${TITLE_DATA_PATH} \
    --lr 0.0001 \
    --lr-decay-style linear \
    --weight-decay 1e-2 \
    --clip-grad 1.0 \
    --lr-warmup-fraction 0.01 \
    --save-interval 4000 \
    --exit-interval 8000 \
    --query-in-block-prob 0.1 \
    --fp16
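The script above pretrains a query/context biencoder with the inverse cloze task (ICT). As a rough illustration of the retriever objective such a run optimizes, here is a simplified sketch, not the Megatron implementation: the use of in-batch negatives and the optional 1/sqrt(hidden_size) scaling tied to --retriever-score-scaling are assumptions inferred from the flags above, and all names are made up for the example.

import math
import torch
import torch.nn.functional as F

def ict_loss(query_embeds, context_embeds, score_scaling=True):
    # query_embeds, context_embeds: [batch, hidden] pooled outputs of the two towers.
    hidden_size = query_embeds.size(1)
    # Dot-product similarity of every query against every context in the batch;
    # the matching context is the positive, the rest act as in-batch negatives.
    scores = torch.matmul(query_embeds, context_embeds.t())
    if score_scaling:
        # Assumed effect of --retriever-score-scaling: divide scores by sqrt(hidden_size).
        scores = scores / math.sqrt(hidden_size)
    labels = torch.arange(scores.size(0), device=scores.device)
    return F.cross_entropy(scores, labels)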
megatron/arguments.py (modified)

@@ -19,7 +19,6 @@ import argparse
 import os
 import torch
-from megatron import fused_kernels

 def parse_args(extra_args_provider=None, defaults={},
                ignore_unknown_args=False):
@@ -39,7 +38,7 @@ def parse_args(extra_args_provider=None, defaults={},
     parser = _add_validation_args(parser)
     parser = _add_data_args(parser)
     parser = _add_autoresume_args(parser)
-    parser = _add_realm_args(parser)
+    parser = _add_biencoder_args(parser)
     parser = _add_vit_args(parser)
     parser = _add_logging_args(parser)
@@ -70,7 +69,7 @@ def parse_args(extra_args_provider=None, defaults={},
     model_parallel_size = args.pipeline_model_parallel_size * \
                           args.tensor_model_parallel_size
     assert args.world_size % model_parallel_size == 0, 'world size is not' \
-        ' divisible by tensor parallel size ({}) times pipeline paralle ' \
+        ' divisible by tensor parallel size ({}) times pipeline parallel ' \
         'size ({})'.format(args.world_size, args.tensor_model_parallel_size,
                            args.pipeline_model_parallel_size)
     args.data_parallel_size = args.world_size // model_parallel_size
@@ -116,15 +115,38 @@ def parse_args(extra_args_provider=None, defaults={},
             print('setting global batch size to {}'.format(
                 args.global_batch_size), flush=True)
     assert args.global_batch_size > 0
+    if args.num_layers_per_virtual_pipeline_stage is not None:
+        assert args.pipeline_model_parallel_size > 2, \
+            'pipeline-model-parallel size should be greater than 2 with ' \
+            'interleaved schedule'
+        assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, \
+            'number of layers is not divisible by number of layers per virtual ' \
+            'pipeline stage'
+        args.virtual_pipeline_model_parallel_size = \
+            (args.num_layers // args.pipeline_model_parallel_size) // \
+            args.num_layers_per_virtual_pipeline_stage
+    else:
+        args.virtual_pipeline_model_parallel_size = None

     # Parameters dtype.
     args.params_dtype = torch.float
     if args.fp16:
+        assert not args.bf16
         args.params_dtype = torch.half
+    if args.bf16:
+        assert not args.fp16
+        args.params_dtype = torch.bfloat16
     if args.rank == 0:
         print('using {} for parameters ...'.format(args.params_dtype),
               flush=True)

+    # If we do accumulation and all-reduces in fp32, we need to have
+    # local DDP and we should set the use-contiguous-buffers-in-ddp.
+    if args.accumulate_allreduce_grads_in_fp32:
+        assert args.DDP_impl == 'local'
+        args.use_contiguous_buffers_in_ddp = True

     if args.dataloader_type is None:
         args.dataloader_type = 'single'
@@ -195,39 +217,14 @@ def parse_args(extra_args_provider=None, defaults={},
     if args.fp16_lm_cross_entropy:
         assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
     if args.fp32_residual_connection:
-        assert args.fp16, \
-            'residual connection in fp32 only supported when using fp16.'
+        assert args.fp16 or args.bf16, \
+            'residual connection in fp32 only supported when using fp16 or bf16.'

     # Activation checkpointing.
     if args.distribute_checkpointed_activations:
         assert args.checkpoint_activations, \
             'for distribute-checkpointed-activations to work you '\
             'need to enable checkpoint-activations'

-    # custom kernel constraints check
-    seq_len = args.seq_length
-    attn_batch_size = \
-        (args.num_attention_heads / args.tensor_model_parallel_size) * \
-        args.micro_batch_size
-    # constraints on sequence length and attn_batch_size to enable warp based
-    # optimization and upper triangular optimization (for causal mask)
-    custom_kernel_constraint = seq_len > 16 and seq_len <= 2048 and \
-        seq_len % 4 == 0 and attn_batch_size % 4 == 0
-    if args.fp16 and custom_kernel_constraint and args.masked_softmax_fusion:
-        print('WARNING: constraints for invoking optimized'
-              ' fused softmax kernel are not met. We default back to unfused'
-              ' kernel invocations.')
-
-    # Load scaled_masked_softmax_fusion_kernels
-    if args.masked_softmax_fusion:
-        fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
-        fused_kernels.load_scaled_masked_softmax_fusion_kernel()
-
-    # Load mixed precision fused layer norm.
-    if args.fp32_residual_connection:
-        fused_kernels.load_fused_mix_prec_layer_norm_kernel()
-
     _print_args(args)
     return args
@@ -299,6 +296,8 @@ def _add_logging_args(parser):
     group.add_argument('--log-params-norm', action='store_true',
                        help='If set, calculate and log parameters norm.')
+    group.add_argument('--log-num-zeros-in-grad', action='store_true',
+                       help='If set, calculate and log the number of zeros in gradient.')
     group.add_argument('--tensorboard-log-interval', type=int, default=1,
                        help='Report to tensorboard interval.')
     group.add_argument('--tensorboard-queue-size', type=int, default=1000,
@@ -517,6 +516,8 @@ def _add_mixed_precision_args(parser):
     group.add_argument('--fp16', action='store_true',
                        help='Run model in fp16 mode.')
+    group.add_argument('--bf16', action='store_true',
+                       help='Run model in bfloat16 mode.')
     group.add_argument('--loss-scale', type=float, default=None,
                        help='Static loss scaling, positive power of 2 '
                        'values can improve fp16 convergence. If None, dynamic'
@@ -538,8 +539,9 @@ def _add_mixed_precision_args(parser):
                        help='Run attention masking and softmax in fp32. '
                        'This flag is ignored unless '
                        '--no-query-key-layer-scaling is specified.')
-    group.add_argument('--fp32-allreduce', action='store_true',
-                       help='All-reduce in fp32')
+    group.add_argument('--accumulate-allreduce-grads-in-fp32',
+                       action='store_true',
+                       help='Gradient accumulation and all-reduce in fp32.')
     group.add_argument('--fp16-lm-cross-entropy', action='store_true',
                        help='Move the cross entropy unreduced loss calculation'
                        'for lm head to fp16.')
@@ -557,6 +559,8 @@ def _add_distributed_args(parser):
     group.add_argument('--model-parallel-size', type=int, default=None,
                        help='Old model parallel argument, do not use. Use '
                        '--tensor-model-parallel-size instead.')
+    group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None,
+                       help='Number of layers per virtual pipeline stage')
     group.add_argument('--distributed-backend', default='nccl',
                        choices=['nccl', 'gloo'],
                        help='Which backend to use for distributed training.')
@@ -564,6 +568,12 @@ def _add_distributed_args(parser):
                        choices=['local', 'torch'],
                        help='which DistributedDataParallel implementation '
                        'to use.')
+    group.add_argument('--use-contiguous-buffers-in-ddp', action='store_true',
+                       help='If set, use contiguous buffer in DDP. Note that '
+                       'this option only works woth local DDP.')
+    group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
+                       help='Use scatter/gather to optimize communication of tensors in pipeline',
+                       dest='scatter_gather_tensors_in_pipeline')
     group.add_argument('--local_rank', type=int, default=None,
                        help='local rank passed from distributed launcher.')
     group.add_argument('--lazy-mpu-init', type=bool, required=False,
@@ -615,6 +625,12 @@ def _add_data_args(parser):
                        'This should be exclusive of --seq-length')
     group.add_argument('--decoder-seq-length', type=int, default=None,
                        help="Maximum decoder sequence length to process.")
+    group.add_argument('--retriever-seq-length', type=int, default=256,
+                       help='Maximum sequence length for the biencoder model '
+                       ' for retriever')
+    group.add_argument('--sample-rate', type=float, default=1.0,
+                       help='sample rate for training data. Supposed to be 0 '
+                       ' < sample_rate < 1')
     group.add_argument('--mask-prob', type=float, default=0.15,
                        help='Probability of replacing a token with mask.')
     group.add_argument('--short-seq-prob', type=float, default=0.1,
@@ -655,13 +671,19 @@ def _add_autoresume_args(parser):
     return parser

-def _add_realm_args(parser):
-    group = parser.add_argument_group(title='realm')
+def _add_biencoder_args(parser):
+    group = parser.add_argument_group(title='biencoder')

     # network size
     group.add_argument('--ict-head-size', type=int, default=None,
                        help='Size of block embeddings to be used in ICT and '
                        'REALM (paper default: 128)')
+    group.add_argument('--biencoder-projection-dim', type=int, default=0,
+                       help='Size of projection head used in biencoder (paper'
+                       ' default: 128)')
+    group.add_argument('--biencoder-shared-query-context-model', action='store_true',
+                       help='Whether to share the parameters of the query '
+                       'and context models or not')

     # checkpointing
     group.add_argument('--ict-load', type=str, default=None,
@@ -678,16 +700,23 @@ def _add_realm_args(parser):
                        'ICT dataset')
     group.add_argument('--use-one-sent-docs', action='store_true',
                        help='Whether to use one sentence documents in ICT')
+    group.add_argument('--evidence-data-path', type=str, default=None,
+                       help='Path to Wikipedia Evidence frm DPR paper')

     # training
-    group.add_argument('--report-topk-accuracies', nargs='+', default=[],
-                       help="Which top-k accuracies to report (e.g. '1 5 20')")
+    group.add_argument('--retriever-report-topk-accuracies', nargs='+', type=int,
+                       default=[], help="Which top-k accuracies to report "
+                       "(e.g. '1 5 20')")
+    group.add_argument('--retriever-score-scaling', action='store_true',
+                       help='Whether to scale retriever scores by inverse '
+                       'square root of hidden size')

     # faiss index
+    group.add_argument('--faiss-use-gpu', action='store_true',
+                       help='Whether create the FaissMIPSIndex on GPU')
     group.add_argument('--block-data-path', type=str, default=None,
                        help='Where to save/load BlockData to/from')
+    group.add_argument('--embedding-path', type=str, default=None,
+                       help='Where to save/load Open-Retrieval Embedding'
+                       ' data to/from')

     # indexer
     group.add_argument('--indexer-batch-size', type=int, default=128,
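One small behavioural change above is that the renamed top-k flag now parses its values as integers (nargs='+', type=int). A minimal standalone sketch of that argparse behaviour, separate from Megatron's own parser:

import argparse

parser = argparse.ArgumentParser()
# Mirrors the new definition: a space-separated list parsed to ints.
parser.add_argument('--retriever-report-topk-accuracies', nargs='+', type=int, default=[])
args = parser.parse_args('--retriever-report-topk-accuracies 1 5 20 100'.split())
print(args.retriever_report_topk_accuracies)  # [1, 5, 20, 100]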
megatron/checkpointing.py (modified)

@@ -21,12 +21,12 @@ import sys
 import numpy as np
 import torch
-from torch.nn.parallel import DistributedDataParallel as torchDDP

 from megatron import (get_args,
                       mpu,
                       print_rank_0,
-                      update_num_microbatches)
+                      update_num_microbatches,
+                      utils)

 _CHECKPOINT_VERSION = None
@@ -111,8 +111,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     args = get_args()

     # Only rank zero of the data parallel writes to the disk.
-    if isinstance(model, torchDDP):
-        model = model.module
+    model = utils.unwrap_model(model)

     print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
         iteration, args.save))
@@ -124,7 +123,12 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
         state_dict['args'] = args
         state_dict['checkpoint_version'] = 3.0
         state_dict['iteration'] = iteration
-        state_dict['model'] = model.state_dict_for_save_checkpoint()
+        if len(model) == 1:
+            state_dict['model'] = model[0].state_dict_for_save_checkpoint()
+        else:
+            for i in range(len(model)):
+                mpu.set_virtual_pipeline_model_parallel_rank(i)
+                state_dict['model%d' % i] = model[i].state_dict_for_save_checkpoint()

         # Optimizer stuff.
         if not args.no_save_optim:
@@ -202,6 +206,33 @@ def _transpose_first_dim(t, num_splits, num_splits_first, model):
     return t

+def fix_query_key_value_ordering(model, checkpoint_version):
+    """Fix up query/key/value matrix ordering if checkpoint
+    version is smaller than 2.0
+    """
+    if checkpoint_version < 2.0:
+        for name, param in model.named_parameters():
+            if name.endswith(('.query_key_value.weight', '.query_key_value.bias')):
+                if checkpoint_version == 0:
+                    fixed_param = _transpose_first_dim(param.data, 3, True, model)
+                elif checkpoint_version == 1.0:
+                    fixed_param = _transpose_first_dim(param.data, 3, False, model)
+                else:
+                    print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
+                    sys.exit()
+                param.data.copy_(fixed_param)
+            if name.endswith(('.key_value.weight', '.key_value.bias')):
+                if checkpoint_version == 0:
+                    fixed_param = _transpose_first_dim(param.data, 2, True, model)
+                elif checkpoint_version == 1.0:
+                    fixed_param = _transpose_first_dim(param.data, 2, False, model)
+                else:
+                    print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
+                    sys.exit()
+                param.data.copy_(fixed_param)
+        print_rank_0(" succesfully fixed query-key-values ordering for"
+                     " checkpoint version {}".format(checkpoint_version))
+
 def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True):
     """Load a model checkpoint and return the iteration.
     strict (bool): whether to strictly enforce that the keys in
@@ -211,8 +242,8 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
     args = get_args()
     load_dir = getattr(args, load_arg)

-    if isinstance(model, torchDDP):
-        model = model.module
+    model = utils.unwrap_model(model)

     # Read the tracker file and set the iteration.
     tracker_filename = get_checkpoint_tracker_filename(load_dir)
@@ -297,30 +328,17 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
             print_rank_0('could not find arguments in the checkpoint ...')

     # Model.
-    model.load_state_dict(state_dict['model'], strict=strict)
+    if len(model) == 1:
+        model[0].load_state_dict(state_dict['model'], strict=strict)
+    else:
+        for i in range(len(model)):
+            mpu.set_virtual_pipeline_model_parallel_rank(i)
+            model[i].load_state_dict(state_dict['model%d' % i], strict=strict)

-    # Fix up query/key/value matrix ordering
-    if get_checkpoint_version() < 2.0:
-        checkpoint_version = get_checkpoint_version()
-        for name, param in model.named_parameters():
-            if name.endswith(('.query_key_value.weight', '.query_key_value.bias')):
-                if checkpoint_version == 0:
-                    fixed_param = _transpose_first_dim(param.data, 3, True, model)
-                elif checkpoint_version == 1.0:
-                    fixed_param = _transpose_first_dim(param.data, 3, False, model)
-                else:
-                    print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
-                    sys.exit()
-                param.data.copy_(fixed_param)
-            if name.endswith(('.key_value.weight', '.key_value.bias')):
-                if checkpoint_version == 0:
-                    fixed_param = _transpose_first_dim(param.data, 2, True, model)
-                elif checkpoint_version == 1.0:
-                    fixed_param = _transpose_first_dim(param.data, 2, False, model)
-                else:
-                    print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
-                    sys.exit()
-                param.data.copy_(fixed_param)
+    # Fix up query/key/value matrix ordering if needed
+    checkpoint_version = get_checkpoint_version()
+    print_rank_0(f' checkpoint version {checkpoint_version}')
+    fix_query_key_value_ordering(model, checkpoint_version)

     # Optimizer.
     if not release and not args.finetune and not args.no_load_optim:
@@ -365,41 +383,42 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
     return iteration

-def load_ict_checkpoint(model, only_query_model=False, only_block_model=False, from_realm_chkpt=False):
-    """selectively load ICT models for indexing/retrieving from ICT or REALM checkpoints"""
+def load_biencoder_checkpoint(model, only_query_model=False,
+        only_context_model=False, custom_load_path=None):
+    """
+    selectively load retrieval models for indexing/retrieving
+    from saved checkpoints
+    """
     args = get_args()

-    if isinstance(model, torchDDP):
-        model = model.module
-
-    load_path = args.load if from_realm_chkpt else args.ict_load
+    model = utils.unwrap_model(model)
+
+    load_path = custom_load_path if custom_load_path is not None else args.load

     tracker_filename = get_checkpoint_tracker_filename(load_path)
     with open(tracker_filename, 'r') as f:
         iteration = int(f.read().strip())

-    # assert iteration > 0
     checkpoint_name = get_checkpoint_name(load_path, iteration, False)
     if mpu.get_data_parallel_rank() == 0:
         print('global rank {} is loading checkpoint {}'.format(
             torch.distributed.get_rank(), checkpoint_name))

     state_dict = torch.load(checkpoint_name, map_location='cpu')
-    ict_state_dict = state_dict['model']
-    if from_realm_chkpt and mpu.get_data_parallel_rank() == 0:
-        print(" loading ICT state dict from REALM", flush=True)
-        ict_state_dict = ict_state_dict['retriever']['ict_model']
+    ret_state_dict = state_dict['model']

     if only_query_model:
-        ict_state_dict.pop('context_model')
-    if only_block_model:
-        ict_state_dict.pop('question_model')
-
-    model.load_state_dict(ict_state_dict)
+        ret_state_dict.pop('context_model')
+    if only_context_model:
+        ret_state_dict.pop('query_model')
+
+    assert len(model) == 1
+    model[0].load_state_dict(ret_state_dict)
     torch.distributed.barrier()

     if mpu.get_data_parallel_rank() == 0:
         print(' successfully loaded {}'.format(checkpoint_name))

     return model
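The refactor above moves the version-dependent reordering into fix_query_key_value_ordering(), which calls the pre-existing _transpose_first_dim() helper with 3 splits for fused query-key-value weights and 2 splits for fused key-value weights. That helper's body is not shown in this diff; as a rough, self-contained illustration of the kind of first-dimension regrouping involved (toy shapes and a made-up function name, not the Megatron helper):

import torch

def regroup_first_dim(t, num_splits, num_splits_first, num_heads, head_dim):
    # Toy illustration: switch the leading dimension between a
    # [num_splits, num_heads, head_dim, ...] grouping and a
    # [num_heads, num_splits, head_dim, ...] grouping of the same data.
    rest = t.size()[1:]
    if num_splits_first:
        t = t.view(num_splits, num_heads, head_dim, *rest)
    else:
        t = t.view(num_heads, num_splits, head_dim, *rest)
    t = t.transpose(0, 1).contiguous()
    return t.view(num_splits * num_heads * head_dim, *rest)

w = torch.randn(3 * 4 * 8, 16)            # fused QKV weight for 4 heads of size 8
w2 = regroup_first_dim(w, 3, True, 4, 8)  # same data, regrouped along dim 0
print(w.shape, w2.shape)                  # torch.Size([96, 16]) torch.Size([96, 16])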
megatron/data/biencoder_dataset_utils.py (new file, mode 100644)

import os
import time

import numpy as np
import torch

from megatron import get_args, get_tokenizer, mpu, print_rank_0
from megatron.data.dataset_utils import create_masked_lm_predictions, \
    pad_and_convert_to_numpy
from megatron.data.data_samplers import MegatronPretrainingSampler

def make_attention_mask(source_block, target_block):
    """
    Returns a 2-dimensional (2-D) attention mask
    :param source_block: 1-D array
    :param target_block: 1-D array
    """
    mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
    mask = mask.astype(np.int64)
    # (source_length, target_length)
    return mask

def get_one_epoch_dataloader(dataset, micro_batch_size=None):
    """Specifically one epoch to be used in an indexing job."""
    args = get_args()

    if micro_batch_size is None:
        micro_batch_size = args.micro_batch_size
    num_workers = args.num_workers

    # Use megatron's sampler with consumed samples set to 0 as
    # this is only for evaluation and don't intend to resume half way.
    # Also, set the drop last to false as don't intend to remove
    # the last batch
    batch_sampler = MegatronPretrainingSampler(
        total_samples=len(dataset),
        consumed_samples=0,
        micro_batch_size=args.micro_batch_size,
        data_parallel_rank=mpu.get_data_parallel_rank(),
        data_parallel_size=mpu.get_data_parallel_world_size(),
        drop_last=False)

    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)

def get_ict_batch(data_iterator):
    # Items and their type.
    keys = ['query_tokens', 'query_mask',
            'context_tokens', 'context_mask', 'block_data']
    datatype = torch.int64

    # Broadcast data.
    if data_iterator is None:
        data = None
    else:
        data = next(data_iterator)
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    query_tokens = data_b['query_tokens'].long()
    query_mask = data_b['query_mask'] < 0.5
    context_tokens = data_b['context_tokens'].long()
    context_mask = data_b['context_mask'] < 0.5
    block_indices = data_b['block_data'].long()

    return query_tokens, query_mask,\
        context_tokens, context_mask, block_indices

def join_str_list(str_list):
    """Join a list of strings, handling spaces appropriately"""
    result = ""
    for s in str_list:
        if s.startswith("##"):
            result += s[2:]
        else:
            result += " " + s
    return result

class BlockSampleData(object):
    """A struct for fully describing a fixed-size block of data as used in REALM

    :param start_idx: for first sentence of the block
    :param end_idx: for last sentence of the block (may be partially truncated in sample construction)
    :param doc_idx: the index of the document from which the block comes in the original indexed dataset
    :param block_idx: a unique integer identifier given to every block.
    """
    def __init__(self, start_idx, end_idx, doc_idx, block_idx):
        self.start_idx = start_idx
        self.end_idx = end_idx
        self.doc_idx = doc_idx
        self.block_idx = block_idx

    def as_array(self):
        return np.array([self.start_idx, self.end_idx,
                         self.doc_idx, self.block_idx]).astype(np.int64)

    def as_tuple(self):
        return self.start_idx, self.end_idx, self.doc_idx, self.block_idx

class BlockSamplesMapping(object):
    def __init__(self, mapping_array):
        # make sure that the array is compatible with BlockSampleData
        assert mapping_array.shape[1] == 4
        self.mapping_array = mapping_array

    def __len__(self):
        return self.mapping_array.shape[0]

    def __getitem__(self, idx):
        """Get the data associated with an indexed sample."""
        sample_data = BlockSampleData(*self.mapping_array[idx])
        return sample_data

def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs,
                              max_num_samples, max_seq_length, seed, name,
                              use_one_sent_docs=False):
    """Get samples mapping for a dataset over fixed size blocks. This function also requires
    a dataset of the titles for the source documents since their lengths must be taken into account.

    :return: samples_mapping (BlockSamplesMapping)
    """

    if not num_epochs:
        if not max_num_samples:
            raise ValueError("Need to specify either max_num_samples "
                             "or num_epochs")
        num_epochs = np.iinfo(np.int32).max - 1
    if not max_num_samples:
        max_num_samples = np.iinfo(np.int64).max - 1

    # Filename of the index mapping
    indexmap_filename = data_prefix
    indexmap_filename += '_{}_indexmap'.format(name)
    if num_epochs != (np.iinfo(np.int32).max - 1):
        indexmap_filename += '_{}ep'.format(num_epochs)
    if max_num_samples != (np.iinfo(np.int64).max - 1):
        indexmap_filename += '_{}mns'.format(max_num_samples)
    indexmap_filename += '_{}msl'.format(max_seq_length)
    indexmap_filename += '_{}s'.format(seed)
    if use_one_sent_docs:
        indexmap_filename += '_1sentok'
    indexmap_filename += '.npy'

    # Build the indexed mapping if not exist.
    if mpu.get_data_parallel_rank() == 0 and \
            not os.path.isfile(indexmap_filename):
        print(' > WARNING: could not find index map file {}, building '
              'the indices on rank 0 ...'.format(indexmap_filename))

        # Make sure the types match the helpers input types.
        assert block_dataset.doc_idx.dtype == np.int64
        assert block_dataset.sizes.dtype == np.int32

        # Build samples mapping
        verbose = torch.distributed.get_rank() == 0
        start_time = time.time()
        print_rank_0(' > building samples index mapping for {} ...'.format(name))

        from megatron.data import helpers
        mapping_array = helpers.build_blocks_mapping(
            block_dataset.doc_idx,
            block_dataset.sizes,
            title_dataset.sizes,
            num_epochs,
            max_num_samples,
            max_seq_length - 3,  # account for added tokens
            seed,
            verbose,
            use_one_sent_docs)

        print_rank_0(' > done building samples index mapping')
        np.save(indexmap_filename, mapping_array, allow_pickle=True)
        print_rank_0(' > saved the index mapping in {}'.format(indexmap_filename))
        # Make sure all the ranks have built the mapping
        print_rank_0(' > elapsed time to build and save samples mapping '
                     '(seconds): {:4f}'.format(time.time() - start_time))

    # This should be a barrier but nccl barrier assumes
    # device_index=rank which is not the case for model
    # parallel case
    counts = torch.cuda.LongTensor([1])
    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
    assert counts[0].item() == torch.distributed.get_world_size(
        group=mpu.get_data_parallel_group())

    # Load indexed dataset.
    print_rank_0(' > loading indexed mapping from {}'.format(indexmap_filename))
    start_time = time.time()

    mapping_array = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
    samples_mapping = BlockSamplesMapping(mapping_array)

    print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(
        time.time() - start_time))
    print_rank_0('    total number of samples: {}'.format(
        mapping_array.shape[0]))

    return samples_mapping
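make_attention_mask() above builds a 2-D mask from two 1-D token arrays, marking positions where both tokens are real tokens (the code treats ids >= 1 as non-padding). A quick usage example with made-up token ids:

import numpy as np

source = np.array([101, 2054, 102, 0])     # token ids with trailing padding (0)
target = np.array([101, 3000, 102, 0, 0])
mask = make_attention_mask(source, target)
print(mask.shape)  # (4, 5); mask[i, j] == 1 only where both tokens are >= 1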
megatron/data/data_samplers.py (modified)

@@ -57,7 +57,7 @@ def build_pretraining_data_loader(dataset, consumed_samples):
 class MegatronPretrainingSampler:

     def __init__(self, total_samples, consumed_samples, micro_batch_size,
-                 data_parallel_rank, data_parallel_size):
+                 data_parallel_rank, data_parallel_size, drop_last=True):
         # Keep a copy of input params for later use.
         self.total_samples = total_samples
         self.consumed_samples = consumed_samples
@@ -65,6 +65,7 @@ class MegatronPretrainingSampler:
         self.data_parallel_rank = data_parallel_rank
         self.micro_batch_times_data_parallel_size = \
             self.micro_batch_size * data_parallel_size
+        self.drop_last = drop_last

         # Sanity checks.
         assert self.total_samples > 0, \
@@ -81,17 +82,26 @@ class MegatronPretrainingSampler:
     def __len__(self):
         return self.total_samples

+    def get_start_end_idx(self):
+        start_idx = self.data_parallel_rank * self.micro_batch_size
+        end_idx = start_idx + self.micro_batch_size
+        return start_idx, end_idx
+
     def __iter__(self):
         batch = []
-        # Last batch if not complete will be dropped.
+        # Last batch will be dropped if drop_last is not set False
         for idx in range(self.consumed_samples, self.total_samples):
             batch.append(idx)
             if len(batch) == self.micro_batch_times_data_parallel_size:
-                start_idx = self.data_parallel_rank * self.micro_batch_size
-                end_idx = start_idx + self.micro_batch_size
+                start_idx, end_idx = self.get_start_end_idx()
                 yield batch[start_idx:end_idx]
                 batch = []

+        # Check the last partial batch and see drop_last is set
+        if len(batch) > 0 and not self.drop_last:
+            start_idx, end_idx = self.get_start_end_idx()
+            yield batch[start_idx:end_idx]

 class MegatronPretrainingRandomSampler:
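To see what the new drop_last handling changes, here is a stripped-down rendition of the same slicing logic as a standalone function, for illustration only (names and values are made up; it is not the sampler class itself):

# Stripped-down rendition of the sampler's per-rank slicing for illustration.
def iter_batches(total_samples, consumed_samples, micro_batch_size,
                 data_parallel_rank, data_parallel_size, drop_last=True):
    batch = []
    start = data_parallel_rank * micro_batch_size
    end = start + micro_batch_size
    for idx in range(consumed_samples, total_samples):
        batch.append(idx)
        if len(batch) == micro_batch_size * data_parallel_size:
            yield batch[start:end]
            batch = []
    # New behaviour: emit the trailing partial batch when drop_last=False.
    if len(batch) > 0 and not drop_last:
        yield batch[start:end]

# 10 samples, 2 data-parallel ranks, micro batch 2: rank 0 now also sees the tail [8, 9].
print(list(iter_batches(10, 0, 2, data_parallel_rank=0, data_parallel_size=2, drop_last=False)))
# [[0, 1], [4, 5], [8, 9]]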
megatron/data/ict_dataset.py (modified)

@@ -9,6 +9,16 @@ from megatron import get_args
 from megatron.data.dataset_utils import get_indexed_dataset_
 from megatron.data.realm_dataset_utils import get_block_samples_mapping

+def make_attention_mask(source_block, target_block):
+    """
+    Returns a 2-dimensional (2-D) attention mask
+    :param source_block: 1-D array
+    :param target_block: 1-D array
+    """
+    mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
+    mask = mask.astype(np.int64)
+    # (source_length, target_length)
+    return mask

 def get_ict_dataset(use_titles=True, query_in_block_prob=1):
     """Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block())
@@ -39,7 +49,7 @@ class ICTDataset(Dataset):
     """Dataset containing sentences and their blocks for an inverse cloze task."""
     def __init__(self, name, block_dataset, title_dataset, data_prefix,
                  num_epochs, max_num_samples, max_seq_length, query_in_block_prob,
-                 seed, use_titles=True, use_one_sent_docs=False):
+                 seed, use_titles=True, use_one_sent_docs=False, binary_head=False):
         self.name = name
         self.seed = seed
         self.max_seq_length = max_seq_length
@@ -93,14 +103,20 @@ class ICTDataset(Dataset):
         block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset]

         query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
-        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
+        context_tokens, context_pad_mask = self.concat_and_pad_tokens(block, title)
+
+        query_mask = make_attention_mask(query_tokens, query_tokens)
+        context_mask = make_attention_mask(context_tokens, context_tokens)

         block_data = sample_data.as_array()

         sample = {
             'query_tokens': query_tokens,
+            'query_mask': query_mask,
             'query_pad_mask': query_pad_mask,
-            'block_tokens': block_tokens,
-            'block_pad_mask': block_pad_mask,
+            'context_tokens': context_tokens,
+            'context_mask': context_mask,
+            'context_pad_mask': context_pad_mask,
             'block_data': block_data,
         }
megatron/data/orqa_wiki_dataset.py (new file, mode 100644)

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Wikipedia dataset from DPR code for ORQA."""

from abc import ABC
import csv
import numpy as np
import random
import torch
from torch.utils.data import Dataset

from megatron import print_rank_0, get_args, get_tokenizer, mpu
from megatron.data.biencoder_dataset_utils import make_attention_mask

def get_open_retrieval_wiki_dataset():
    args = get_args()
    tokenizer = get_tokenizer()

    dataset = OpenRetrievalEvidenceDataset('2018 Wikipedia from DPR codebase',
                                           'evidence',
                                           args.evidence_data_path,
                                           tokenizer,
                                           args.retriever_seq_length)
    return dataset

def get_open_retrieval_batch(data_iterator):
    # Items and their type.
    keys = ['row_id', 'context', 'context_mask', 'context_types',
            'context_pad_mask']
    datatype = torch.int64

    # Broadcast data.
    data = None if data_iterator is None else next(data_iterator)
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    row_id = data_b['row_id'].long()
    context = data_b['context'].long()

    # TODO: make the context mask a binary one
    context_mask = (data_b['context_mask'] < 0.5)

    context_types = data_b['context_types'].long()
    context_pad_mask = data_b['context_pad_mask'].long()

    return row_id, context, context_mask, context_types, context_pad_mask

def build_tokens_types_paddings_from_text(row, tokenizer, max_seq_length):
    """Build token types and paddings, trim if needed, and pad if needed."""

    title_ids = tokenizer.tokenize(row['title'])
    context_ids = tokenizer.tokenize(row['text'])

    # Appending the title of the context at front
    extended_context_ids = title_ids + [tokenizer.sep_id] + context_ids

    context_ids, context_types, context_pad_mask = \
        build_tokens_types_paddings_from_ids(extended_context_ids,
            max_seq_length, tokenizer.cls, tokenizer.sep, tokenizer.pad)

    return context_ids, context_types, context_pad_mask

# noinspection DuplicatedCode
def build_tokens_types_paddings_from_ids(text_ids, max_seq_length,
                                         cls_id, sep_id, pad_id):
    """Build token types and paddings, trim if needed, and pad if needed."""
    enc_ids = []
    tokentypes_enc = []

    # [CLS].
    enc_ids.append(cls_id)
    tokentypes_enc.append(0)

    # A.
    len_src = len(text_ids)
    enc_ids.extend(text_ids)
    tokentypes_enc.extend([0] * len_src)

    # Cap the size.
    if len(enc_ids) > max_seq_length - 1:
        enc_ids = enc_ids[0: max_seq_length - 1]
        tokentypes_enc = tokentypes_enc[0: max_seq_length - 1]

    # [SEP].
    enc_ids.append(sep_id)
    tokentypes_enc.append(0)

    num_tokens_enc = len(enc_ids)
    # Padding.
    padding_length = max_seq_length - len(enc_ids)
    if padding_length > 0:
        enc_ids.extend([pad_id] * padding_length)
        tokentypes_enc.extend([pad_id] * padding_length)

    pad_mask = ([1] * num_tokens_enc) + ([0] * padding_length)
    pad_mask = np.array(pad_mask, dtype=np.int64)

    return enc_ids, tokentypes_enc, pad_mask

def build_sample(row_id, context_ids, context_types, context_pad_mask):
    """Convert to numpy and return a sample consumed by the batch producer."""

    context_ids = np.array(context_ids, dtype=np.int64)
    context_types = np.array(context_types, dtype=np.int64)
    context_mask = make_attention_mask(context_ids, context_ids)

    sample = ({
        'row_id': row_id,
        'context': context_ids,
        'context_mask': context_mask,
        'context_types': context_types,
        'context_pad_mask': context_pad_mask
    })
    return sample

class OpenRetrievalEvidenceDataset(ABC, Dataset):
    """Open Retrieval Evidence dataset class."""

    def __init__(self, task_name, dataset_name, datapath, tokenizer,
                 max_seq_length):
        # Store inputs.
        self.task_name = task_name
        self.dataset_name = dataset_name
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        print_rank_0(' > building {} dataset for {}:'.format(self.task_name,
                                                             self.dataset_name))
        # Process the files.
        print_rank_0(datapath)
        self.samples, self.id2text = self.process_samples_from_single_path(
            datapath)

        args = get_args()
        if args.sample_rate < 1:  # subsample
            k = int(len(self.samples) * args.sample_rate)
            self.samples = random.sample(self.samples, k)

        print_rank_0('  >> total number of samples: {}'.format(
            len(self.samples)))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        row = self.samples[idx]

        context_ids, context_types, context_pad_mask = \
            build_tokens_types_paddings_from_text(row, self.tokenizer,
                self.max_seq_length)

        sample = build_sample(row['doc_id'],
                              context_ids,
                              context_types,
                              context_pad_mask)
        return sample

    @staticmethod
    def process_samples_from_single_path(filename):
        print_rank_0(' > Processing {} ...'.format(filename))
        total = 0

        rows = []
        id2text = {}

        with open(filename) as tsvfile:
            reader = csv.reader(tsvfile, delimiter='\t')
            next(reader, None)  # skip the headers
            for row in reader:
                # file format: doc_id, doc_text, title
                doc_id = int(row[0])
                text = row[1]
                title = row[2]

                rows.append({'doc_id': doc_id,
                             'text': text,
                             'title': title})

                assert doc_id not in id2text
                id2text[doc_id] = (text, title)

                total += 1
                if total % 100000 == 0:
                    print_rank_0('  > processed {} rows so far ...'.format(
                        total))

        print_rank_0(' >> processed {} samples.'.format(len(rows)))
        return rows, id2text
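build_tokens_types_paddings_from_ids() above wraps the ids with [CLS]/[SEP], truncates to max_seq_length, and pads. A quick illustration with made-up ids (101 for CLS, 102 for SEP, and 0 for pad follow the usual BERT vocabulary conventions and are only assumed here):

enc_ids, token_types, pad_mask = build_tokens_types_paddings_from_ids(
    text_ids=[7592, 2088, 2003], max_seq_length=8, cls_id=101, sep_id=102, pad_id=0)
print(enc_ids)   # [101, 7592, 2088, 2003, 102, 0, 0, 0]
print(pad_mask)  # [1 1 1 1 1 0 0 0]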
megatron/data/realm_index.py
View file @
ee7b19e7
...
@@ -14,34 +14,36 @@ def detach(tensor):
...
@@ -14,34 +14,36 @@ def detach(tensor):
return
tensor
.
detach
().
cpu
().
numpy
()
return
tensor
.
detach
().
cpu
().
numpy
()
class
BlockData
(
object
):
class
OpenRetreivalDataStore
(
object
):
"""Serializable data structure for holding data for blocks -- embeddings and necessary metadata for REALM"""
"""
def
__init__
(
self
,
block_data_path
=
None
,
load_from_path
=
True
,
rank
=
None
):
Serializable data structure for holding data for blocks --
embeddings and necessary metadata for Retriever
"""
def
__init__
(
self
,
embedding_path
=
None
,
load_from_path
=
True
,
rank
=
None
):
self
.
embed_data
=
dict
()
self
.
embed_data
=
dict
()
self
.
meta_data
=
dict
()
if
embedding_path
is
None
:
if
block_data_path
is
None
:
args
=
get_args
()
args
=
get_args
()
block_data
_path
=
args
.
block_data
_path
embedding
_path
=
args
.
embedding
_path
rank
=
args
.
rank
rank
=
args
.
rank
self
.
block_data_path
=
block_data
_path
self
.
embedding_path
=
embedding
_path
self
.
rank
=
rank
self
.
rank
=
rank
if
load_from_path
:
if
load_from_path
:
self
.
load_from_file
()
self
.
load_from_file
()
block_data_name
=
os
.
path
.
splitext
(
self
.
block_data
_path
)[
0
]
block_data_name
=
os
.
path
.
splitext
(
self
.
embedding
_path
)[
0
]
self
.
temp_dir_name
=
block_data_name
+
'_tmp'
self
.
temp_dir_name
=
block_data_name
+
'_tmp'
def
state
(
self
):
def
state
(
self
):
return
{
return
{
'embed_data'
:
self
.
embed_data
,
'embed_data'
:
self
.
embed_data
,
'meta_data'
:
self
.
meta_data
,
}
}
def
clear
(
self
):
def
clear
(
self
):
"""Clear the embedding data structures to save memory.
"""
The metadata ends up getting used, and is also much smaller in dimensionality
Clear the embedding data structures to save memory.
so it isn't really worth clearing.
The metadata ends up getting used, and is also much smaller in
dimensionality so it isn't really worth clearing.
"""
"""
self
.
embed_data
=
dict
()
self
.
embed_data
=
dict
()
...
@@ -50,38 +52,39 @@ class BlockData(object):
...
@@ -50,38 +52,39 @@ class BlockData(object):
if
mpu
.
is_unitialized
()
or
mpu
.
get_data_parallel_rank
()
==
0
:
if
mpu
.
is_unitialized
()
or
mpu
.
get_data_parallel_rank
()
==
0
:
print
(
"
\n
> Unpickling BlockData"
,
flush
=
True
)
print
(
"
\n
> Unpickling BlockData"
,
flush
=
True
)
state_dict
=
pickle
.
load
(
open
(
self
.
block_data
_path
,
'rb'
))
state_dict
=
pickle
.
load
(
open
(
self
.
embedding
_path
,
'rb'
))
if
mpu
.
is_unitialized
()
or
mpu
.
get_data_parallel_rank
()
==
0
:
if
mpu
.
is_unitialized
()
or
mpu
.
get_data_parallel_rank
()
==
0
:
print
(
">> Finished unpickling BlockData
\n
"
,
flush
=
True
)
print
(
">> Finished unpickling BlockData
\n
"
,
flush
=
True
)
self
.
embed_data
=
state_dict
[
'embed_data'
]
self
.
embed_data
=
state_dict
[
'embed_data'
]
self
.
meta_data
=
state_dict
[
'meta_data'
]
def
add_block_data
(
self
,
block_indices
,
block_embeds
,
block_metas
,
allow_overwrite
=
False
):
def
add_block_data
(
self
,
row_id
,
block_embeds
,
allow_overwrite
=
False
):
"""Add data for set of blocks
"""
:param block_indices: 1D array of unique int ids for the blocks
Add data for set of blocks
:param row_id: 1D array of unique int ids for the blocks
:param block_embeds: 2D array of embeddings of the blocks
:param block_embeds: 2D array of embeddings of the blocks
:param block_metas: 2D array of metadata for the blocks.
In the case of retriever this will be [start_idx, end_idx, doc_idx]
In the case of REALM this will be [start_idx, end_idx, doc_idx]
"""
"""
for
idx
,
embed
,
meta
in
zip
(
block_indices
,
block_embeds
,
block_meta
s
):
for
idx
,
embed
in
zip
(
row_id
,
block_embed
s
):
if
not
allow_overwrite
and
idx
in
self
.
embed_data
:
if
not
allow_overwrite
and
idx
in
self
.
embed_data
:
raise
ValueError
(
"Unexpectedly tried to overwrite block data"
)
raise
ValueError
(
"Unexpectedly tried to overwrite block data"
)
self
.
embed_data
[
idx
]
=
np
.
float16
(
embed
)
self
.
embed_data
[
idx
]
=
np
.
float16
(
embed
)
self
.
meta_data
[
idx
]
=
meta
def
save_shard
(
self
):
def
save_shard
(
self
):
"""Save the block data that was created this in this process"""
"""
Save the block data that was created this in this process
"""
if
not
os
.
path
.
isdir
(
self
.
temp_dir_name
):
if
not
os
.
path
.
isdir
(
self
.
temp_dir_name
):
os
.
makedirs
(
self
.
temp_dir_name
,
exist_ok
=
True
)
os
.
makedirs
(
self
.
temp_dir_name
,
exist_ok
=
True
)
# save the data for each shard
# save the data for each shard
with
open
(
'{}/{}.pkl'
.
format
(
self
.
temp_dir_name
,
self
.
rank
),
'wb'
)
as
data_file
:
with
open
(
'{}/{}.pkl'
.
format
(
self
.
temp_dir_name
,
self
.
rank
),
'wb'
)
\
pickle
.
dump
(
self
.
state
(),
data_file
)
as
writer
:
pickle
.
dump
(
self
.
state
(),
writer
)
def
merge_shards_and_save
(
self
):
def
merge_shards_and_save
(
self
):
"""
Combine all the shards made using
self.
save_shard
()"""
#
Combine all the shards made using save_shard
shard_names
=
os
.
listdir
(
self
.
temp_dir_name
)
shard_names
=
os
.
listdir
(
self
.
temp_dir_name
)
seen_own_shard
=
False
seen_own_shard
=
False
...
@@ -96,15 +99,15 @@ class BlockData(object):
...
@@ -96,15 +99,15 @@ class BlockData(object):
old_size
=
len
(
self
.
embed_data
)
old_size
=
len
(
self
.
embed_data
)
shard_size
=
len
(
data
[
'embed_data'
])
shard_size
=
len
(
data
[
'embed_data'
])
# add the shard's data and check to make sure there is no overlap
# add the shard's data and check to make sure there
# is no overlap
self
.
embed_data
.
update
(
data
[
'embed_data'
])
self
.
embed_data
.
update
(
data
[
'embed_data'
])
self
.
meta_data
.
update
(
data
[
'meta_data'
])
assert
len
(
self
.
embed_data
)
==
old_size
+
shard_size
assert
len
(
self
.
embed_data
)
==
old_size
+
shard_size
assert
seen_own_shard
assert
seen_own_shard
# save the consolidated shards and remove temporary directory
# save the consolidated shards and remove temporary directory
with
open
(
self
.
block_data
_path
,
'wb'
)
as
final_file
:
with
open
(
self
.
embedding
_path
,
'wb'
)
as
final_file
:
pickle
.
dump
(
self
.
state
(),
final_file
)
pickle
.
dump
(
self
.
state
(),
final_file
)
shutil
.
rmtree
(
self
.
temp_dir_name
,
ignore_errors
=
True
)
shutil
.
rmtree
(
self
.
temp_dir_name
,
ignore_errors
=
True
)
...
@@ -113,18 +116,22 @@ class BlockData(object):

 class FaissMIPSIndex(object):
-    """Wrapper object for a BlockData which similarity search via FAISS under the hood"""
-    def __init__(self, embed_size, block_data=None, use_gpu=False):
+    """
+    Wrapper object for a BlockData which similarity search via FAISS under the hood
+    """
+    def __init__(self, embed_size, embed_data=None, use_gpu=False):
         self.embed_size = embed_size
-        self.block_data = block_data
+        self.embed_data = embed_data
         self.use_gpu = use_gpu
-        self.id_map = dict()

-        self.block_mips_index = None
-        self._set_block_index()
+        self.mips_index = None
+        self._set_mips_index()

-    def _set_block_index(self):
-        """Create a Faiss Flat index with inner product as the metric to search against"""
+    def _set_mips_index(self):
+        """
+        Create a Faiss Flat index with inner product as the metric
+        to search against
+        """
         try:
             import faiss
         except ImportError:
...
@@ -132,85 +139,86 @@ class FaissMIPSIndex(object):
         if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
             print("\n> Building index", flush=True)
-        self.block_mips_index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)
+
+        cpu_index = faiss.IndexFlatIP(self.embed_size)

         if self.use_gpu:
             # create resources and config for GpuIndex
-            res = faiss.StandardGpuResources()
-            config = faiss.GpuIndexFlatConfig()
-            config.device = torch.cuda.current_device()
+            config = faiss.GpuMultipleClonerOptions()
+            config.shard = True
             config.useFloat16 = True
-            self.block_mips_index = faiss.GpuIndexFlat(res, self.block_mips_index, config)
+            gpu_index = faiss.index_cpu_to_all_gpus(cpu_index, co=config)
+            self.mips_index = faiss.IndexIDMap(gpu_index)
             if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
-                print(">> Initialized index on GPU {}".format(self.block_mips_index.getDevice()), flush=True)
+                print(">> Initialized index on GPU", flush=True)
         else:
             # CPU index supports IDs so wrap with IDMap
-            self.block_mips_index = faiss.IndexIDMap(self.block_mips_index)
+            self.mips_index = faiss.IndexIDMap(cpu_index)
             if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
                 print(">> Initialized index on CPU", flush=True)

-        # if we were constructed with a BlockData, then automatically load it when the FAISS structure is built
-        if self.block_data is not None:
-            self.add_block_embed_data(self.block_data)
+        # if we were constructed with a BlockData, then automatically load it
+        # when the FAISS structure is built
+        if self.embed_data is not None:
+            self.add_embed_data(self.embed_data)

     def reset_index(self):
-        """Delete existing index and create anew"""
-        del self.block_mips_index
+        """Delete existing index and create a new"""
+        del self.mips_index

         # reset the block data so that _set_block_index will reload it as well
-        if self.block_data is not None:
-            block_data_path = self.block_data.block_data_path
-            del self.block_data
-            self.block_data = BlockData(block_data_path)
+        if self.embed_data is not None:
+            embed_data_path = self.embed_data.embedding_path
+            del self.embed_data
+            self.embed_data = OpenRetreivalDataStore(embed_data_path)

-        self._set_block_index()
+        self._set_mips_index()

-    def add_block_embed_data(self, all_block_data):
+    def update_index(self):
+        """Delete existing index and create a new"""
+        del self.mips_index
+
+        # reset the block data so that _set_mips_index will reload it as well
+        if self.embed_data is not None:
+            self.embed_data.load_from_file()
+        self._set_mips_index()
+
+    def add_embed_data(self, all_embed_data):
         """Add the embedding of each block to the underlying FAISS index"""
         # this assumes the embed_data is a dict : {int: np.array<float>}
-        block_indices, block_embeds = zip(*all_block_data.embed_data.items())
+        block_indices, block_embeds = zip(*all_embed_data.embed_data.items())

-        # the embeddings have to be entered in as float32 even though the math internally is done with float16.
-        block_embeds_arr = np.float32(np.array(block_embeds))
-        block_indices_arr = np.array(block_indices)
+        # the embeddings have to be entered in as float32 even though the math
+        # internally is done with float16.
+        embeds_arr = np.float32(np.array(block_embeds))
+        indices_arr = np.array(block_indices)

-        # faiss GpuIndex doesn't work with IDMap wrapper so store ids to map back with
-        if self.use_gpu:
-            for i, idx in enumerate(block_indices):
-                self.id_map[i] = idx
-
         # we no longer need the embedding data since it's in the index now
-        all_block_data.clear()
+        all_embed_data.clear()

-        if self.use_gpu:
-            self.block_mips_index.add(block_embeds_arr)
-        else:
-            self.block_mips_index.add_with_ids(block_embeds_arr, block_indices_arr)
+        self.mips_index.add_with_ids(embeds_arr, indices_arr)

         if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
             print(">>> Finished adding block data to index", flush=True)

     def search_mips_index(self, query_embeds, top_k, reconstruct=True):
-        """Get the top-k blocks by the index distance metric.
-        :param reconstruct: if True: return a [num_queries x k x embed_dim] array of blocks
-                            if False: return [num_queries x k] array of distances, and another for indices
+        """
+        Get the top-k blocks by the index distance metric.
+
+        :param reconstruct: if True: return a [num_queries x k x embed_dim]
+                                array of blocks
+                            if False: return [num_queries x k] array of
+                                distances, and another for indices
         """
         query_embeds = np.float32(detach(query_embeds))

         if reconstruct:
             # get the vectors themselves
-            top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds, top_k)
+            top_k_block_embeds = self.mips_index.search_and_reconstruct(\
+                query_embeds, top_k)
             return top_k_block_embeds
         else:
             # get distances and indices of closest vectors
-            distances, block_indices = self.block_mips_index.search(query_embeds, top_k)
-            if self.use_gpu:
-                fresh_indices = np.zeros(block_indices.shape)
-                for i, j in itertools.product(block_indices.shape):
-                    fresh_indices[i, j] = self.id_map[block_indices[i, j]]
-                block_indices = fresh_indices
+            distances, block_indices = self.mips_index.search(query_embeds, top_k)
             return distances, block_indices
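For readers unfamiliar with FAISS, the refactored _set_mips_index / add_embed_data / search_mips_index flow boils down to a flat inner-product index wrapped in an IndexIDMap. A minimal standalone sketch, assuming a plain dict of {id: embedding} in place of OpenRetreivalDataStore:

# Standalone sketch of the MIPS-index flow above (FAISS + NumPy only).
import numpy as np
import faiss

embed_size = 128
embed_data = {i: np.random.rand(embed_size).astype(np.float32) for i in range(1000)}

# flat inner-product index; wrap with IDMap so we can keep our own block ids
cpu_index = faiss.IndexFlatIP(embed_size)
# GPU variant (not shown): faiss.index_cpu_to_all_gpus(cpu_index,
#     co=faiss.GpuMultipleClonerOptions()) before wrapping with IndexIDMap
mips_index = faiss.IndexIDMap(cpu_index)

indices, embeds = zip(*embed_data.items())
embeds_arr = np.float32(np.array(embeds))
indices_arr = np.array(indices)
mips_index.add_with_ids(embeds_arr, indices_arr)

# top-k maximum inner product search
query = np.float32(np.random.rand(4, embed_size))
distances, block_indices = mips_index.search(query, 5)
print(distances.shape, block_indices.shape)  # (4, 5) (4, 5)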
megatron/fused_kernels/__init__.py  View file @ ee7b19e7
...
@@ -13,114 +13,97 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
 import pathlib
 import subprocess
-import os

 from torch.utils import cpp_extension

 # Setting this param to a list has a problem of generating different
 # compilation commands (with diferent order of architectures) and
 # leading to recompilation of fused kernels. Set it to empty string
 # to avoid recompilation and assign arch flags explicity in
 # extra_cuda_cflags below
 os.environ["TORCH_CUDA_ARCH_LIST"] = ""


-def get_cuda_bare_metal_version(cuda_dir):
-    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
-                                         universal_newlines=True)
-    output = raw_output.split()
-    release_idx = output.index("release") + 1
-    release = output[release_idx].split(".")
-    bare_metal_major = release[0]
-    bare_metal_minor = release[1][0]
-    return raw_output, bare_metal_major, bare_metal_minor
-
-
-def create_build_dir(buildpath):
-    try:
-        os.mkdir(buildpath)
-    except OSError:
-        if not os.path.isdir(buildpath):
-            print(f"Creation of the build directory {buildpath} failed")
-
-
-def load_scaled_upper_triang_masked_softmax_fusion_kernel():
+def load(args):

-    # Check, if CUDA11 is installed for compute capability 8.0
+    # Check if cuda 11 is installed for compute capability 8.0
     cc_flag = []
-    _, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
+    _, bare_metal_major, _ = _get_cuda_bare_metal_version(
+        cpp_extension.CUDA_HOME)
     if int(bare_metal_major) >= 11:
         cc_flag.append('-gencode')
         cc_flag.append('arch=compute_80,code=sm_80')

+    # Build path
     srcpath = pathlib.Path(__file__).parent.absolute()
     buildpath = srcpath / 'build'
-    create_build_dir(buildpath)
+    _create_build_dir(buildpath)

-    scaled_upper_triang_masked_softmax_cuda = cpp_extension.load(
-        name='scaled_upper_triang_masked_softmax_cuda',
-        sources=[srcpath / 'scaled_upper_triang_masked_softmax.cpp',
-                 srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu'],
-        build_directory=buildpath,
-        extra_cflags=['-O3',],
-        extra_cuda_cflags=['-O3', '-gencode', 'arch=compute_70,code=sm_70',
-                           '-U__CUDA_NO_HALF_OPERATORS__',
-                           '-U__CUDA_NO_HALF_CONVERSIONS__',
-                           '--expt-relaxed-constexpr',
-                           '--expt-extended-lambda',
-                           '--use_fast_math'] + cc_flag)
-
-
-def load_scaled_masked_softmax_fusion_kernel():
-
-    # Check, if CUDA11 is installed for compute capability 8.0
-    cc_flag = []
-    _, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
-    if int(bare_metal_major) >= 11:
-        cc_flag.append('-gencode')
-        cc_flag.append('arch=compute_80,code=sm_80')
-
-    srcpath = pathlib.Path(__file__).parent.absolute()
-    buildpath = srcpath / 'build'
-    create_build_dir(buildpath)
-
-    scaled_masked_softmax_cuda = cpp_extension.load(
-        name='scaled_masked_softmax_cuda',
-        sources=[srcpath / 'scaled_masked_softmax.cpp',
-                 srcpath / 'scaled_masked_softmax_cuda.cu'],
-        build_directory=buildpath,
-        extra_cflags=['-O3',],
-        extra_cuda_cflags=['-O3', '-gencode', 'arch=compute_70,code=sm_70',
-                           '-U__CUDA_NO_HALF_OPERATORS__',
-                           '-U__CUDA_NO_HALF_CONVERSIONS__',
-                           '--expt-relaxed-constexpr',
-                           '--expt-extended-lambda',
-                           '--use_fast_math'] + cc_flag)
-
-
-def load_fused_mix_prec_layer_norm_kernel():
-
-    # Check, if CUDA11 is installed for compute capability 8.0
-    cc_flag = []
-    _, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
-    if int(bare_metal_major) >= 11:
-        cc_flag.append('-gencode')
-        cc_flag.append('arch=compute_80,code=sm_80')
-
-    srcpath = pathlib.Path(__file__).parent.absolute()
-    buildpath = srcpath / 'build'
-    create_build_dir(buildpath)
-
-    fused_mix_prec_layer_norm_cuda = cpp_extension.load(
-        name='fused_mix_prec_layer_norm_cuda',
-        sources=[srcpath / 'layer_norm_cuda.cpp',
-                 srcpath / 'layer_norm_cuda_kernel.cu'],
-        build_directory=buildpath,
-        extra_cflags=['-O3'],
-        extra_cuda_cflags=['-O3', '-gencode', 'arch=compute_70,code=sm_70',
-                           '-maxrregcount=50',
-                           '--use_fast_math'] + cc_flag)
+    # Helper function to build the kernels.
+    def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
+        return cpp_extension.load(
+            name=name,
+            sources=sources,
+            build_directory=buildpath,
+            extra_cflags=['-O3',],
+            extra_cuda_cflags=['-O3',
+                               '-gencode', 'arch=compute_70,code=sm_70',
+                               '--use_fast_math'] + extra_cuda_flags + cc_flag,
+            verbose=(args.rank == 0)
+        )
+
+    # ==============
+    # Fused softmax.
+    # ==============
+
+    if args.masked_softmax_fusion:
+        extra_cuda_flags = ['-U__CUDA_NO_HALF_OPERATORS__',
+                            '-U__CUDA_NO_HALF_CONVERSIONS__',
+                            '--expt-relaxed-constexpr',
+                            '--expt-extended-lambda']
+
+        # Upper triangular softmax.
+        sources = [srcpath / 'scaled_upper_triang_masked_softmax.cpp',
+                   srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu']
+        scaled_upper_triang_masked_softmax_cuda = _cpp_extention_load_helper(
+            "scaled_upper_triang_masked_softmax_cuda", sources, extra_cuda_flags)
+
+        # Masked softmax.
+        sources = [srcpath / 'scaled_masked_softmax.cpp',
+                   srcpath / 'scaled_masked_softmax_cuda.cu']
+        scaled_masked_softmax_cuda = _cpp_extention_load_helper(
+            "scaled_masked_softmax_cuda", sources, extra_cuda_flags)
+
+    # =================================
+    # Mixed precision fused layer norm.
+    # =================================
+
+    extra_cuda_flags = ['-maxrregcount=50']
+    sources = [srcpath / 'layer_norm_cuda.cpp',
+               srcpath / 'layer_norm_cuda_kernel.cu']
+    fused_mix_prec_layer_norm_cuda = _cpp_extention_load_helper(
+        "fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags)
+
+
+def _get_cuda_bare_metal_version(cuda_dir):
+    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
+                                         universal_newlines=True)
+    output = raw_output.split()
+    release_idx = output.index("release") + 1
+    release = output[release_idx].split(".")
+    bare_metal_major = release[0]
+    bare_metal_minor = release[1][0]
+
+    return raw_output, bare_metal_major, bare_metal_minor
+
+
+def _create_build_dir(buildpath):
+    try:
+        os.mkdir(buildpath)
+    except OSError:
+        if not os.path.isdir(buildpath):
+            print(f"Creation of the build directory {buildpath} failed")
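The new load(args) entry point above builds each kernel through one shared helper instead of three near-identical loader functions. A reduced sketch of that pattern, using only torch.utils.cpp_extension; the function name, paths, and flag selection below are illustrative, not the exact Megatron call:

# Sketch of the JIT-build pattern, reduced to one kernel.
import pathlib
from torch.utils import cpp_extension

def build_fused_kernels(rank, masked_softmax_fusion=True):
    srcpath = pathlib.Path(__file__).parent.absolute()
    buildpath = srcpath / 'build'
    buildpath.mkdir(exist_ok=True)

    def _load(name, sources, extra_cuda_flags):
        # shared compile options; per-kernel options arrive via extra_cuda_flags
        return cpp_extension.load(
            name=name,
            sources=sources,
            build_directory=str(buildpath),
            extra_cflags=['-O3'],
            extra_cuda_cflags=['-O3', '--use_fast_math'] + extra_cuda_flags,
            verbose=(rank == 0))

    if masked_softmax_fusion:
        return _load(
            'scaled_masked_softmax_cuda',
            [srcpath / 'scaled_masked_softmax.cpp',
             srcpath / 'scaled_masked_softmax_cuda.cu'],
            ['-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__',
             '--expt-relaxed-constexpr', '--expt-extended-lambda'])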
megatron/fused_kernels/layer_norm_cuda.cpp  View file @ ee7b19e7
...
@@ -24,16 +24,12 @@
 #include "compat.h"

 namespace {
 void compute_n1_n2(
     at::Tensor input,
-#ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-#else
-    at::IntList normalized_shape,
-#endif
     int& n1,
     int& n2)
 {
     int idiff = input.ndimension() - normalized_shape.size();
     n2 = 1;
     for (int i = 0;  i < (int)normalized_shape.size();  ++i) {
...
@@ -47,11 +43,7 @@ void compute_n1_n2(
 }

 void check_args(
-#ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-#else
-    at::IntList normalized_shape,
-#endif
     at::Tensor gamma,
     at::Tensor beta
     )
...
@@ -62,11 +54,7 @@ void check_args(
 void check_args(
     at::Tensor input,
-#ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-#else
-    at::IntList normalized_shape,
-#endif
     int& n1,
     int& n2
     )
...
@@ -102,11 +90,7 @@ void check_args(
 void check_args(
     at::Tensor input,
-#ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-#else
-    at::IntList normalized_shape,
-#endif
     at::Tensor gamma,
     at::Tensor beta,
     int& n1,
...
@@ -125,60 +109,42 @@ void cuda_layer_norm(
     at::Tensor* input,
     int n1,
     int n2,
-#ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-#else
-    at::IntList normalized_shape,
-#endif
     at::Tensor* gamma,
     at::Tensor* beta,
     double epsilon);

-#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
 #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
 #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

-std::vector<at::Tensor> layer_norm(
-    at::Tensor input,
-#ifdef VERSION_GE_1_1
-    at::IntArrayRef normalized_shape,
-#else
-    at::IntList normalized_shape,
-#endif
-    double epsilon) {
-  CHECK_INPUT(input);
-  int n1, n2;
-  check_args(input, normalized_shape, n1, n2);
-  at::Tensor output = at::empty_like(input);
-  at::Tensor mean = at::empty({n1}, input.options().dtype(
-      input.scalar_type() == at::ScalarType::Half ? at::ScalarType::Float : input.scalar_type()));
-  at::Tensor invvar = at::empty_like(mean);
-  cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2,
-      normalized_shape, NULL, NULL, epsilon);
-  return {output, mean, invvar};
-}

 std::vector<at::Tensor> layer_norm_affine(
     at::Tensor input,
-#ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-#else
-    at::IntList normalized_shape,
-#endif
     at::Tensor gamma,
     at::Tensor beta,
     double epsilon) {
   CHECK_INPUT(input);
   CHECK_INPUT(gamma);
   CHECK_INPUT(beta);
   int n1, n2;
   check_args(input, normalized_shape, gamma, beta, n1, n2);
-  at::Tensor output = at::empty_like(input, input.options().dtype(at::ScalarType::Half));
-  at::Tensor mean = at::empty({n1}, input.options().dtype(
-      input.scalar_type() == at::ScalarType::Half ? at::ScalarType::Float : input.scalar_type()));
+  at::Tensor output = at::empty_like(
+      input, gamma.options().dtype(gamma.scalar_type()));
+  at::Tensor mean = at::empty(
+      {n1}, input.options().dtype(at::ScalarType::Float));
   at::Tensor invvar = at::empty_like(mean);
   cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2,
       normalized_shape, &gamma, &beta, epsilon);
   return {output, mean, invvar};
 }

 void cuda_layer_norm_gradient(
     at::Tensor* dout,
     at::Tensor* mean,
...
@@ -186,11 +152,7 @@ void cuda_layer_norm_gradient(
     at::Tensor* input,
     int n1,
     int n2,
-#ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-#else
-    at::IntList normalized_shape,
-#endif
     at::Tensor* gamma,
     at::Tensor* beta,
     double epsilon,
...
@@ -199,62 +161,41 @@ void cuda_layer_norm_gradient(
     at::Tensor* grad_beta
     );

-at::Tensor layer_norm_gradient(
-    at::Tensor dout,
-    at::Tensor mean,
-    at::Tensor invvar,
-    at::Tensor input,
-#ifdef VERSION_GE_1_1
-    at::IntArrayRef normalized_shape,
-#else
-    at::IntList normalized_shape,
-#endif
-    double epsilon) {
-  CHECK_INPUT(dout);
-  CHECK_INPUT(mean);
-  CHECK_INPUT(invvar);
-  CHECK_INPUT(input);
-  int n1, n2;
-  check_args(input, normalized_shape, n1, n2);
-  at::Tensor grad_input = at::empty_like(input);
-  cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2,
-      normalized_shape, NULL, NULL, epsilon, &grad_input, NULL, NULL);
-  return grad_input;
-}

 std::vector<at::Tensor> layer_norm_gradient_affine(
     at::Tensor dout,
     at::Tensor mean,
     at::Tensor invvar,
     at::Tensor input,
-#ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-#else
-    at::IntList normalized_shape,
-#endif
     at::Tensor gamma,
     at::Tensor beta,
     double epsilon) {
   CHECK_INPUT(dout);
   CHECK_INPUT(mean);
   CHECK_INPUT(invvar);
   CHECK_INPUT(input);
   CHECK_INPUT(gamma);
   CHECK_INPUT(beta);
   int n1, n2;
   check_args(input, normalized_shape, gamma, beta, n1, n2);
   at::Tensor grad_input = at::empty_like(input);
   at::Tensor grad_gamma = at::empty_like(gamma);
   at::Tensor grad_beta = at::empty_like(beta);
   cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2,
       normalized_shape, &gamma, &beta, epsilon,
       &grad_input, &grad_gamma, &grad_beta);
   return {grad_input, grad_gamma, grad_beta};
 }

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)");
-  m.def("forward", &layer_norm, "LayerNorm forward (CUDA)");
-  m.def("backward_affine", &layer_norm_gradient_affine, "LayerNorm backward (CUDA)");
-  m.def("backward", &layer_norm_gradient, "LayerNorm backward (CUDA)");
+  m.def("forward_affine", &layer_norm_affine,
+        "LayerNorm forward (CUDA)");
+  m.def("backward_affine", &layer_norm_gradient_affine,
+        "LayerNorm backward (CUDA)");
 }
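A rough usage sketch of the rebuilt bindings, assuming the extension was already JIT-built in the same process (for example through megatron.fused_kernels.load(args)) and is importable as fused_mix_prec_layer_norm_cuda; the argument order follows the signatures in the diff above:

# Hedged sketch: calling the fused layer-norm bindings directly.
import torch
import fused_mix_prec_layer_norm_cuda  # assumed already built and importable

hidden = 1024
x = torch.randn(8, 512, hidden, device='cuda', dtype=torch.half)
gamma = torch.ones(hidden, device='cuda', dtype=torch.half)
beta = torch.zeros(hidden, device='cuda', dtype=torch.half)

# forward_affine returns (output, mean, invvar); per the diff above the output
# follows gamma's dtype while mean/invvar stay in float32
output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(
    x, (hidden,), gamma, beta, 1e-5)

grad_out = torch.randn_like(output).contiguous()
grad_input, grad_gamma, grad_beta = fused_mix_prec_layer_norm_cuda.backward_affine(
    grad_out, mean, invvar, x, (hidden,), gamma, beta, 1e-5)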
megatron/fused_kernels/layer_norm_cuda_kernel.cu  View file @ ee7b19e7
...
@@ -285,15 +285,6 @@ struct SharedMemory <float>
     }
 };

-template <>
-struct SharedMemory <double>
-{
-    __device__ double *getPointer()
-    {
-        extern __shared__ double s_double[];
-        return s_double;
-    }
-};
 }

 template<typename T, typename U, typename V> __global__
...
@@ -656,6 +647,9 @@ void cuComputeGradInput(
   }
 }

 template<typename T, typename U, typename V>
 void HostApplyLayerNorm(
     V* output,
...
@@ -671,7 +665,8 @@ void HostApplyLayerNorm(
 {
     auto stream = at::cuda::getCurrentCUDAStream().stream();
     const dim3 threads(32,4,1);
     const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
     const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
     int nshared =
         threads.y > 1 ?
...
@@ -687,6 +682,7 @@ void HostApplyLayerNorm(
           gamma, beta);
 }

 void cuda_layer_norm(
     at::Tensor* output,
     at::Tensor* mean,
...
@@ -704,21 +700,21 @@ void cuda_layer_norm(
     double epsilon)
 {
     using namespace at;
-    DISPATCH_DOUBLE_FLOAT_AND_HALF(input->scalar_type(), 0, "layer_norm_cuda_kernel",
-        using accscalar_t = at::acc_type<scalar_t_0, true>;
-        using output_t = at::Half;
+    DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
+        input->scalar_type(), output->scalar_type(), "cuda_layer_norm_kernel",
         HostApplyLayerNorm(
-            output->DATA_PTR<output_t>(),
-            mean->DATA_PTR<accscalar_t>(),
-            invvar->DATA_PTR<accscalar_t>(),
-            input->DATA_PTR<scalar_t_0>(),
+            output->DATA_PTR<scalar_t_out>(),
+            mean->DATA_PTR<float>(),
+            invvar->DATA_PTR<float>(),
+            input->DATA_PTR<scalar_t_in>(),
             n1, n2,
             epsilon,
-            gamma != NULL ? gamma->DATA_PTR<output_t>() : NULL,
-            beta != NULL ? beta->DATA_PTR<output_t>() : NULL);
+            gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
+            beta != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL);
       )
 }

 template<typename T, typename U, typename V>
 void HostLayerNormGradient(
     const V* dout,
...
@@ -742,10 +738,12 @@ void HostLayerNormGradient(
       const int part_size = 16;
       const dim3 threads2(32,4,1);
       const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1);
       const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
       const int nshared2_b = threads2.x * threads2.y * sizeof(U);
       const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
-      at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype(input->scalar_type() == at::ScalarType::Half ? at::ScalarType::Float : input->scalar_type()));
+      at::Tensor part_grad_gamma = at::empty(
+          {part_size,n2}, input->options().dtype(at::ScalarType::Float));
       at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
       cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
                       dout,
...
@@ -770,7 +768,8 @@ void HostLayerNormGradient(
     }

     // compute grad_input
     const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
     const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
     const dim3 threads1(32,4,1);
     int nshared =
...
@@ -788,6 +787,7 @@ void HostLayerNormGradient(
             grad_input);
 }

 void cuda_layer_norm_gradient(
     at::Tensor* dout,
     at::Tensor* mean,
...
@@ -808,22 +808,22 @@ void cuda_layer_norm_gradient(
     at::Tensor* grad_beta)
 {
     using namespace at;
-    DISPATCH_FLOAT_AND_HALF(input->scalar_type(), 0, "cuComputeGradInput",
-        using accscalar_t = at::acc_type<scalar_t_0, true>;
-        using output_t = at::Half;
+    DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
+        input->scalar_type(), gamma->scalar_type(),
+        "cuda_layer_norm_gradient_kernel",
         HostLayerNormGradient(
-            dout->DATA_PTR<output_t>(),
-            mean->DATA_PTR<accscalar_t>(),
-            invvar->DATA_PTR<accscalar_t>(),
+            dout->DATA_PTR<scalar_t_out>(),
+            mean->DATA_PTR<float>(),
+            invvar->DATA_PTR<float>(),
             input,
             n1, n2,
             // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
             // if gamma Tensor is NULL on input.
-            gamma != NULL ? gamma->DATA_PTR<output_t>() : NULL,
-            gamma != NULL ? beta->DATA_PTR<output_t>() : NULL,
+            gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
+            gamma != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL,
             epsilon,
-            grad_input->DATA_PTR<scalar_t_0>(),
-            gamma != NULL ? grad_gamma->DATA_PTR<output_t>() : NULL,
-            gamma != NULL ? grad_beta->DATA_PTR<output_t>() : NULL);
+            grad_input->DATA_PTR<scalar_t_in>(),
+            gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL,
+            gamma != NULL ? grad_beta->DATA_PTR<scalar_t_out>() : NULL);
       )
 }
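The dispatch change above routes fp16 and bf16 tensors through the same kernels while keeping the statistics in fp32. A plain-PyTorch reference of that dtype convention, for clarity only:

# Reference (non-fused) computation matching the dtype convention in the
# kernels above: statistics in float32, output cast to the weight dtype.
import torch

def mixed_prec_layer_norm(x, gamma, beta, eps=1e-5):
    x_fp32 = x.float()
    mean = x_fp32.mean(dim=-1, keepdim=True)
    var = x_fp32.var(dim=-1, unbiased=False, keepdim=True)
    invvar = torch.rsqrt(var + eps)
    y = (x_fp32 - mean) * invvar
    # output follows gamma/beta dtype (fp16 or bf16), mean/invvar stay fp32
    return y.to(gamma.dtype) * gamma + beta, mean.squeeze(-1), invvar.squeeze(-1)

x = torch.randn(4, 16, dtype=torch.bfloat16)
g = torch.ones(16, dtype=torch.bfloat16)
b = torch.zeros(16, dtype=torch.bfloat16)
out, mean, invvar = mixed_prec_layer_norm(x, g, b)
print(out.dtype, mean.dtype)  # torch.bfloat16 torch.float32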
megatron/fused_kernels/scaled_masked_softmax.cpp  View file @ ee7b19e7
...
@@ -37,8 +37,9 @@ torch::Tensor fwd(
     torch::Tensor const& mask,
     float scale_factor) {
   AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
-  AT_ASSERTM(input.scalar_type() == at::ScalarType::Half,
-      "Only HALF is supported");
+  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
+      (input.scalar_type() == at::ScalarType::BFloat16),
+      "Only fp16 and bf16 are supported");
   AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");

   return fwd_cuda(input, mask, scale_factor);
...
@@ -52,10 +53,12 @@ torch::Tensor bwd(
   AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor");
   AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor");

-  AT_ASSERTM(output_grads.scalar_type() == at::ScalarType::Half,
-      "Only HALF is supported");
-  AT_ASSERTM(softmax_results.scalar_type() == at::ScalarType::Half,
-      "Only HALF is supported");
+  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
+      (output_grads.scalar_type() == at::ScalarType::BFloat16),
+      "Only fp16 and bf16 are supported");
+  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
+      (softmax_results.scalar_type() == at::ScalarType::BFloat16),
+      "Only fp16 and bf16 are supported");

   return bwd_cuda(output_grads, softmax_results, scale_factor);
 }
...
megatron/fused_kernels/scaled_masked_softmax.h  View file @ ee7b19e7
...
@@ -26,6 +26,27 @@
 namespace {

+template <typename Datatype, int ELEMENTS_PER_LDG>
+__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
+
+template <>
+__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }
+
+template <>
+__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }
+
+template <>
+__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }
+
+template <>
+__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }
+
+template <>
+__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }
+
+template <>
+__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) { *((half2*) dst) = *((half2*) src); }
+
 int log2_ceil(int value) {
     int log2_value = 0;
     while ((1 << log2_value) < value) ++log2_value;
...
@@ -90,13 +111,14 @@ __global__ void scaled_masked_softmax_warp_forward(
     constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
     constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
     constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
+    constexpr int ELEMENTS_PER_LDG_STG = 4;

     // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
     // gridDim/blockIdx = (seq_len, attn_heads, batches)
     int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) + threadIdx.y) * WARP_BATCH;
     int pad_first_batch = 0;
     if (pad_batches != 1) { // bert style
         pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH;
     } else { // gpt2 style
         pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
     }
...
@@ -110,29 +132,40 @@ __global__ void scaled_masked_softmax_warp_forward(
     // there might be multiple batches per warp. compute the index within the batch
     int local_idx = threadIdx.x;

-    src += first_batch * element_count + local_idx;
-    dst += first_batch * element_count + local_idx;
-    mask += pad_first_batch * element_count + local_idx;
+    src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
+    dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
+    mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;

     // load data from global memory
     acc_t elements[WARP_BATCH][WARP_ITERATIONS];
+    input_t temp_data[ELEMENTS_PER_LDG_STG];
+    uint8_t temp_mask[ELEMENTS_PER_LDG_STG];
     #pragma unroll
     for (int i = 0;  i < WARP_BATCH;  ++i) {
         int batch_element_count = (i >= local_batches) ? 0 : element_count;
         #pragma unroll
-        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
-            int element_index = local_idx + it * WARP_SIZE;
+        for (int it = 0;  it < WARP_ITERATIONS;  it += ELEMENTS_PER_LDG_STG) {
+            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
             if (element_index < batch_element_count) {
                 int itr_idx = i*element_count+it*WARP_SIZE;
-                if (mask[itr_idx] != 1) {
-                    elements[i][it] = (acc_t)src[itr_idx] * scale;
-                } else {
-                    elements[i][it] = -10000.0;
-                }
+                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
+                copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(temp_mask, mask + itr_idx);
+
+                #pragma unroll
+                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                    if (temp_mask[element] != 1) {
+                        elements[i][it + element] = (acc_t)temp_data[element] * scale;
+                    } else {
+                        elements[i][it + element] = -10000.0;
+                    }
+                }
             } else {
-                elements[i][it] = -std::numeric_limits<acc_t>::infinity();
+                #pragma unroll
+                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                    elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
+                }
             }
         }
     }
...
@@ -161,15 +194,20 @@ __global__ void scaled_masked_softmax_warp_forward(
     warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

     // store result
+    output_t out[ELEMENTS_PER_LDG_STG];
     #pragma unroll
     for (int i = 0;  i < WARP_BATCH;  ++i) {
         if (i >= local_batches)
             break;
         #pragma unroll
-        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
-            int element_index = local_idx + it * WARP_SIZE;
+        for (int it = 0;  it < WARP_ITERATIONS;  it += ELEMENTS_PER_LDG_STG) {
+            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
             if (element_index < element_count) {
-                dst[i*element_count+it*WARP_SIZE] = (output_t)(elements[i][it] / sum[i]);
+                #pragma unroll
+                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                    out[element] = elements[i][it + element] / sum[i];
+                }
+                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
             } else {
                 break;
             }
...
@@ -192,6 +230,7 @@ __global__ void scaled_masked_softmax_warp_backward(
     constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
     constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
     constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
+    constexpr int ELEMENTS_PER_LDG_STG = 4;

     // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
     // gridDim/blockIdx = (seq_len, attn_heads, batches)
...
@@ -207,36 +246,36 @@ __global__ void scaled_masked_softmax_warp_backward(
     int local_idx = threadIdx.x;

     // the first element to process by the current thread
-    int thread_offset = first_batch * element_count + local_idx;
+    int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
     grad += thread_offset;
     output += thread_offset;
     gradInput += thread_offset;

     // load data from global memory
     acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
-    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
+    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
+    input_t temp_grad[ELEMENTS_PER_LDG_STG];
+    input_t temp_output[ELEMENTS_PER_LDG_STG];
     #pragma unroll
     for (int i = 0;  i < WARP_BATCH;  ++i) {
         int batch_element_count = (i >= local_batches) ? 0 : element_count;
         #pragma unroll
-        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
-            int element_index = local_idx + it * WARP_SIZE;
+        for (int it = 0;  it < WARP_ITERATIONS;  it += ELEMENTS_PER_LDG_STG) {
+            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
             if (element_index < batch_element_count) {
-                output_reg[i][it] = output[i*element_count+it*WARP_SIZE];
-            } else {
-                output_reg[i][it] = acc_t(0);
-            }
-        }
-        #pragma unroll
-        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
-            int element_index = local_idx + it * WARP_SIZE;
-            if (element_index < batch_element_count) {
-                grad_reg[i][it] = (acc_t)grad[i*element_count+it*WARP_SIZE] * output_reg[i][it];
+                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count + it * WARP_SIZE);
+                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count + it * WARP_SIZE);
+
+                #pragma unroll
+                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                    output_reg[i][it + element] = (acc_t)temp_output[element];
+                }
+                #pragma unroll
+                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                    grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
+                }
             }
         }
     }
...
@@ -257,11 +296,16 @@ __global__ void scaled_masked_softmax_warp_backward(
         if (i >= local_batches)
             break;
         #pragma unroll
-        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
-            int element_index = local_idx + it * WARP_SIZE;
+        for (int it = 0;  it < WARP_ITERATIONS;  it += ELEMENTS_PER_LDG_STG) {
+            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
             if (element_index < element_count) {
                 // compute gradients
-                gradInput[i*element_count+it*WARP_SIZE] = (output_t)(scale * (grad_reg[i][it] - output_reg[i][it] * sum[i]));
+                output_t out[ELEMENTS_PER_LDG_STG];
+                #pragma unroll
+                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                    out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
+                }
+                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count + it * WARP_SIZE, out);
             }
         }
     }
...
@@ -299,8 +343,8 @@ void dispatch_scaled_masked_softmax_forward(
         constexpr int threads_per_block = 128;
         int warps_per_block = (threads_per_block / warp_size);
         int batches_per_block = warps_per_block * batches_per_warp;
         TORCH_INTERNAL_ASSERT(query_seq_len % batches_per_block == 0);
         dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches);
         dim3 threads(warp_size, warps_per_block, 1);
         // Launch code would be more elegant if C++ supported FOR CONSTEXPR
...
@@ -388,7 +432,7 @@ void dispatch_scaled_masked_softmax_backward(
         constexpr int threads_per_block = 128;
         int warps_per_block = (threads_per_block / warp_size);
         int batches_per_block = warps_per_block * batches_per_warp;
         int blocks = batch_count/batches_per_block;
         dim3 threads(warp_size, warps_per_block, 1);
         // Launch code would be more elegant if C++ supported FOR CONSTEXPR
...
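What the warp kernels in this header compute is, numerically, just a scaled softmax with masked positions pushed to -10000 before normalization. A non-fused PyTorch reference, with illustrative shapes:

# Plain-PyTorch reference of the masked softmax the warp kernels implement:
# scale the scores, set masked positions to -10000, softmax over the last dim.
import torch

def scaled_masked_softmax_ref(scores, mask, scale):
    # scores: [b, np, sq, sk] in fp16/bf16; mask: [b, 1, sq, sk] with 1 = masked
    scores = scores * scale
    scores = scores.masked_fill(mask.bool(), -10000.0)
    return torch.softmax(scores, dim=-1)

scores = torch.randn(2, 12, 128, 128, dtype=torch.half)
mask = torch.zeros(2, 1, 128, 128, dtype=torch.uint8)
mask[..., 64:] = 1  # e.g. right-padding
probs = scaled_masked_softmax_ref(scores, mask, scale=0.125)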
megatron/fused_kernels/scaled_masked_softmax_cuda.cu  View file @ ee7b19e7
...
@@ -19,10 +19,10 @@
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
 #include <cuda_profiler_api.h>
-#include "THC/THC.h"
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
 #include "scaled_masked_softmax.h"
+#include "type_shim.h"

 namespace multihead_attn {
 namespace fused_softmax {
...
@@ -56,16 +56,20 @@ torch::Tensor fwd_cuda(
   void* mask_ptr = static_cast<void*>(mask.data_ptr());
   void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());

-  dispatch_scaled_masked_softmax_forward<half, half, float>(
-      reinterpret_cast<half*>(softmax_results_ptr),
-      reinterpret_cast<const half*>(input_ptr),
-      reinterpret_cast<const uint8_t*>(mask_ptr),
-      scale_factor,
-      query_seq_len,
-      key_seq_len,
-      batches,
-      attn_heads,
-      pad_batches);
+  DISPATCH_HALF_AND_BFLOAT(
+      input.scalar_type(),
+      "dispatch_scaled_masked_softmax_forward",
+      dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
+          reinterpret_cast<scalar_t*>(softmax_results_ptr),
+          reinterpret_cast<const scalar_t*>(input_ptr),
+          reinterpret_cast<const uint8_t*>(mask_ptr),
+          scale_factor,
+          query_seq_len,
+          key_seq_len,
+          batches,
+          attn_heads,
+          pad_batches);
+  );

   return softmax_results;
 }
...
@@ -86,15 +90,19 @@ torch::Tensor bwd_cuda(
   void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

   //Softmax Grad
-  dispatch_scaled_masked_softmax_backward<half, half, float>(
-      reinterpret_cast<half*>(output_grads_ptr),
-      reinterpret_cast<half*>(output_grads_ptr),
-      reinterpret_cast<half const*>(softmax_results.data_ptr()),
-      scale_factor,
-      query_seq_len,
-      key_seq_len,
-      batches,
-      attn_heads);
+  DISPATCH_HALF_AND_BFLOAT(
+      output_grads_.scalar_type(),
+      "dispatch_scaled_masked_softmax_backward",
+      dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
+          reinterpret_cast<scalar_t*>(output_grads_ptr),
+          reinterpret_cast<scalar_t*>(output_grads_ptr),
+          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
+          scale_factor,
+          query_seq_len,
+          key_seq_len,
+          batches,
+          attn_heads);
+  );

   //backward pass is completely in-place
   return output_grads;
...
megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp  View file @ ee7b19e7
...
@@ -33,8 +33,9 @@ torch::Tensor bwd_cuda(
 torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
   AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
-  AT_ASSERTM(input.scalar_type() == at::ScalarType::Half,
-      "Only HALF is supported");
+  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
+      (input.scalar_type() == at::ScalarType::BFloat16),
+      "Only fp16 and bf16 are supported");

   return fwd_cuda(input, scale_factor);
 }
...
@@ -47,10 +48,12 @@ torch::Tensor bwd(
   AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
   AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");

-  AT_ASSERTM(output_grads.scalar_type() == at::ScalarType::Half,
-      "Only HALF is supported");
-  AT_ASSERTM(softmax_results.scalar_type() == at::ScalarType::Half,
-      "Only HALF is supported");
+  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
+      (output_grads.scalar_type() == at::ScalarType::BFloat16),
+      "Only fp16 and bf16 are supported");
+  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
+      (softmax_results.scalar_type() == at::ScalarType::BFloat16),
+      "Only fp16 and bf16 are supported");

   return bwd_cuda(output_grads, softmax_results, scale_factor);
 }
...
@@ -61,7 +64,7 @@ torch::Tensor bwd(
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("forward",
         &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
         "Self Multihead Attention scaled, time masked softmax -- Forward.");
   m.def("backward",
         &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
...
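The forward/backward bindings above implement a causal (upper-triangular) scaled softmax for fp16/bf16 inputs. A plain-PyTorch reference of the same computation, with illustrative shapes and scale:

# Reference of what the fused upper-triangular (causal) softmax computes.
import torch

def scaled_upper_triang_masked_softmax_ref(x, scale):
    # x: [attn_batches, sq, sk] with sq == sk; positions above the diagonal
    # (future tokens) are masked out before the softmax
    sq = x.size(1)
    causal_mask = torch.triu(
        torch.ones(sq, sq, dtype=torch.bool, device=x.device), diagonal=1)
    scores = (x * scale).masked_fill(causal_mask, float('-inf'))
    return torch.softmax(scores, dim=-1)

x = torch.randn(12, 128, 128, dtype=torch.bfloat16)
probs = scaled_upper_triang_masked_softmax_ref(x, scale=1.0 / 8.0)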
megatron/fused_kernels/scaled_upper_triang_masked_softmax.h
View file @
ee7b19e7
...
@@ -21,11 +21,47 @@
...
@@ -21,11 +21,47 @@
#include <cfloat>
#include <cfloat>
#include <limits>
#include <limits>
#include <stdint.h>
#include <stdint.h>
#include <cuda_fp16.h>
#include <c10/macros/Macros.h>
#include <c10/macros/Macros.h>
namespace
{
namespace
{
template
<
typename
Datatype
,
int
ELEMENTS_PER_LDG
>
__device__
__inline__
void
copy_vector
(
Datatype
*
dst
,
const
Datatype
*
src
);
template
<
>
__device__
__inline__
void
copy_vector
<
c10
::
BFloat16
,
1
>
(
c10
::
BFloat16
*
dst
,
const
c10
::
BFloat16
*
src
)
{
*
dst
=
*
src
;
}
template
<
>
__device__
__inline__
void
copy_vector
<
c10
::
BFloat16
,
4
>
(
c10
::
BFloat16
*
dst
,
const
c10
::
BFloat16
*
src
)
{
*
((
float2
*
)
dst
)
=
*
((
float2
*
)
src
);
}
template
<
>
__device__
__inline__
void
copy_vector
<
c10
::
Half
,
1
>
(
c10
::
Half
*
dst
,
const
c10
::
Half
*
src
)
{
*
dst
=
*
src
;
}
template
<
>
__device__
__inline__
void
copy_vector
<
c10
::
Half
,
4
>
(
c10
::
Half
*
dst
,
const
c10
::
Half
*
src
)
{
*
((
float2
*
)
dst
)
=
*
((
float2
*
)
src
);
}
template
<
>
__device__
__inline__
void
copy_vector
<
uint8_t
,
1
>
(
uint8_t
*
dst
,
const
uint8_t
*
src
)
{
*
dst
=
*
src
;
}
template
<
>
__device__
__inline__
void
copy_vector
<
uint8_t
,
4
>
(
uint8_t
*
dst
,
const
uint8_t
*
src
)
{
*
((
half2
*
)
dst
)
=
*
((
half2
*
)
src
);
}
template
<
typename
Datatype
,
int
ELEMENTS_PER_LDG
>
__device__
__inline__
void
copy_zero_vector
(
Datatype
*
dst
);
template
<
>
__device__
__inline__
void
copy_zero_vector
<
c10
::
BFloat16
,
1
>
(
c10
::
BFloat16
*
dst
)
{
*
dst
=
0.0
;
}
template
<
>
__device__
__inline__
void
copy_zero_vector
<
c10
::
BFloat16
,
4
>
(
c10
::
BFloat16
*
dst
)
{
*
((
float2
*
)
dst
)
=
make_float2
(
0.0
f
,
0.0
f
);
}
template
<
>
__device__
__inline__
void
copy_zero_vector
<
c10
::
Half
,
1
>
(
c10
::
Half
*
dst
)
{
*
dst
=
0.0
;
}
template
<
>
__device__
__inline__
void
copy_zero_vector
<
c10
::
Half
,
4
>
(
c10
::
Half
*
dst
)
{
*
((
float2
*
)
dst
)
=
make_float2
(
0.0
f
,
0.0
f
);
}
int
log2_ceil
(
int
value
)
{
int
log2_ceil
(
int
value
)
{
int
log2_value
=
0
;
int
log2_value
=
0
;
while
((
1
<<
log2_value
)
<
value
)
++
log2_value
;
while
((
1
<<
log2_value
)
<
value
)
++
log2_value
;
...
@@ -73,7 +109,7 @@ __device__ __forceinline__ void warp_reduce(acc_t* sum) {
...
@@ -73,7 +109,7 @@ __device__ __forceinline__ void warp_reduce(acc_t* sum) {
* Extended softmax (from native aten pytorch) with following additional features
* Extended softmax (from native aten pytorch) with following additional features
* 1) input scaling
* 1) input scaling
* 2) Implicit time (diagonal masking)
* 2) Implicit time (diagonal masking)
*/
*/
template
<
typename
input_t
,
typename
output_t
,
typename
acc_t
,
int
log2_elements
>
template
<
typename
input_t
,
typename
output_t
,
typename
acc_t
,
int
log2_elements
>
__global__
void
scaled_upper_triang_masked_softmax_warp_forward
(
__global__
void
scaled_upper_triang_masked_softmax_warp_forward
(
output_t
*
dst
,
output_t
*
dst
,
...
@@ -89,10 +125,11 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
...
@@ -89,10 +125,11 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
constexpr
int
WARP_SIZE
=
(
next_power_of_two
<
C10_WARP_SIZE
)
?
next_power_of_two
:
C10_WARP_SIZE
;
constexpr
int
WARP_SIZE
=
(
next_power_of_two
<
C10_WARP_SIZE
)
?
next_power_of_two
:
C10_WARP_SIZE
;
constexpr
int
WARP_ITERATIONS
=
next_power_of_two
/
WARP_SIZE
;
constexpr
int
WARP_ITERATIONS
=
next_power_of_two
/
WARP_SIZE
;
constexpr
int
WARP_BATCH
=
(
next_power_of_two
<=
128
)
?
2
:
1
;
constexpr
int
WARP_BATCH
=
(
next_power_of_two
<=
128
)
?
2
:
1
;
constexpr
int
ELEMENTS_PER_LDG_STG
=
4
;
int
first_batch
=
(
blockDim
.
y
*
blockIdx
.
y
+
threadIdx
.
y
)
*
gridDim
.
x
*
WARP_BATCH
+
blockIdx
.
x
;
int
first_batch
=
(
blockDim
.
y
*
blockIdx
.
y
+
threadIdx
.
y
)
*
gridDim
.
x
*
WARP_BATCH
+
blockIdx
.
x
;
int
local_seq
=
blockIdx
.
x
+
1
;
int
local_seq
=
blockIdx
.
x
+
1
;
int
warp_iteration_limit
=
(
local_seq
+
WARP_SIZE
-
1
)
/
WARP_SIZE
;
int
warp_iteration_limit
=
(
local_seq
+
ELEMENTS_PER_LDG_STG
*
WARP_SIZE
-
1
)
/
WARP_SIZE
;
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
// many batches have to computed within this WARP.
...
@@ -103,22 +140,36 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
...
@@ -103,22 +140,36 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
// there might be multiple batches per warp. compute the index within the batch
// there might be multiple batches per warp. compute the index within the batch
int
local_idx
=
threadIdx
.
x
;
int
local_idx
=
threadIdx
.
x
;
src
+=
first_batch
*
stride
+
local_idx
;
src
+=
first_batch
*
stride
+
ELEMENTS_PER_LDG_STG
*
local_idx
;
dst
+=
first_batch
*
stride
+
local_idx
;
dst
+=
first_batch
*
stride
+
ELEMENTS_PER_LDG_STG
*
local_idx
;
// load data from global memory
// load data from global memory
acc_t
elements
[
WARP_BATCH
][
WARP_ITERATIONS
];
acc_t
elements
[
WARP_BATCH
][
WARP_ITERATIONS
];
input_t
temp_data
[
ELEMENTS_PER_LDG_STG
];
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
int
batch_element_count
=
(
i
>=
local_batches
)
?
0
:
local_seq
;
int
batch_element_count
=
(
i
>=
local_batches
)
?
0
:
local_seq
;
#pragma unroll
#pragma unroll
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
++
it
)
{
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
it
+=
ELEMENTS_PER_LDG_STG
)
{
int
element_index
=
local_idx
+
it
*
WARP_SIZE
;
int
element_index
=
ELEMENTS_PER_LDG_STG
*
local_idx
+
it
*
WARP_SIZE
;
if
(
element_index
<
batch_element_count
)
{
if
(
element_index
<
batch_element_count
)
{
elements
[
i
][
it
]
=
(
acc_t
)
src
[
i
*
element_count
*
stride
+
it
*
WARP_SIZE
]
*
scale
;
copy_vector
<
input_t
,
ELEMENTS_PER_LDG_STG
>
(
temp_data
,
src
+
i
*
element_count
*
stride
+
it
*
WARP_SIZE
);
#pragma unroll
for
(
int
element
=
0
;
element
<
ELEMENTS_PER_LDG_STG
;
++
element
)
{
if
((
element_index
+
element
)
<
batch_element_count
)
{
elements
[
i
][
it
+
element
]
=
(
acc_t
)
temp_data
[
element
]
*
scale
;
}
else
{
elements
[
i
][
it
+
element
]
=
-
std
::
numeric_limits
<
acc_t
>::
infinity
();
}
}
}
else
{
}
else
{
elements
[
i
][
it
]
=
-
std
::
numeric_limits
<
acc_t
>::
infinity
();
#pragma unroll
for
(
int
element
=
0
;
element
<
ELEMENTS_PER_LDG_STG
;
++
element
)
{
elements
[
i
][
it
+
element
]
=
-
std
::
numeric_limits
<
acc_t
>::
infinity
();
}
}
}
}
}
}
}
...
@@ -140,26 +191,37 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
...
@@ -140,26 +191,37 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
     for (int i = 0;  i < WARP_BATCH;  ++i) {
         #pragma unroll
         for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
             if (it < warp_iteration_limit) {
                 elements[i][it] = std::exp((elements[i][it] - max_value[i]));
                 sum[i] += elements[i][it];
             }
         }
     }
     warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
 
     // store result
+    output_t out[ELEMENTS_PER_LDG_STG];
     #pragma unroll
     for (int i = 0;  i < WARP_BATCH;  ++i) {
         if (i >= local_batches)
             break;
         #pragma unroll
-        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
-            int element_index = local_idx + it * WARP_SIZE;
+        for (int it = 0;  it < WARP_ITERATIONS;  it += ELEMENTS_PER_LDG_STG) {
+            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
             if (element_index < local_seq) {
-                dst[i * element_count * stride + it * WARP_SIZE] = (output_t)(elements[i][it] / sum[i]);
+                #pragma unroll
+                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                    if (element_index + element < local_seq) {
+                        out[element] = elements[i][it + element] / sum[i];
+                    } else {
+                        out[element] = 0;
+                    }
+                }
+                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE, out);
             } else if (element_index < element_count) {
-                dst[i * element_count * stride + it * WARP_SIZE] = 0;
+                copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE);
             } else {
                 break;
             }
         }
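As context for the store path above: local_seq = blockIdx.x + 1 means row `row` of each attention score matrix only attends to positions 0..row, and everything past local_seq is written out as exact zeros. A host-side reference of what one row of the forward kernel computes, for illustration only and not part of the commit:

#include <cmath>
#include <limits>
#include <vector>

// Reference: upper-triangular (causal) masked softmax for one row of a
// seq_len x seq_len score matrix. Mirrors the kernel: scale the scores,
// subtract the row max for stability, exponentiate, normalize, and leave
// the masked tail as exact zeros.
std::vector<float> causal_softmax_row(const std::vector<float>& scores,
                                      int row, float scale)
{
    const int seq_len = static_cast<int>(scores.size());
    const int valid = row + 1;                       // the kernel's local_seq
    std::vector<float> out(seq_len, 0.0f);           // masked positions stay 0

    float max_val = -std::numeric_limits<float>::infinity();
    for (int j = 0; j < valid; ++j)
        max_val = std::fmax(max_val, scores[j] * scale);

    float sum = 0.0f;
    for (int j = 0; j < valid; ++j) {
        out[j] = std::exp(scores[j] * scale - max_val);
        sum += out[j];
    }
    for (int j = 0; j < valid; ++j)
        out[j] /= sum;

    return out;
}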
...
@@ -183,6 +245,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
     constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
     constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
     constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
+    constexpr int ELEMENTS_PER_LDG_STG = 4;
 
     int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
     int local_seq = blockIdx.x + 1;
...
@@ -197,37 +260,41 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
     int local_idx = threadIdx.x;
 
     // the first element to process by the current thread
-    int thread_offset = first_batch * stride + local_idx;
+    int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
     grad += thread_offset;
     output += thread_offset;
     gradInput += thread_offset;
 
     // load data from global memory
     acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
-    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
+    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
+    input_t temp_grad[ELEMENTS_PER_LDG_STG];
+    input_t temp_output[ELEMENTS_PER_LDG_STG];
     #pragma unroll
     for (int i = 0;  i < WARP_BATCH;  ++i) {
         int batch_element_count = (i >= local_batches) ? 0 : local_seq;
 
         #pragma unroll
-        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
-            int element_index = local_idx + it * WARP_SIZE;
+        for (int it = 0;  it < WARP_ITERATIONS;  it += ELEMENTS_PER_LDG_STG) {
+            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
             if (element_index < batch_element_count) {
-                output_reg[i][it] = output[i * element_count * stride + it * WARP_SIZE];
-            } else {
-                output_reg[i][it] = acc_t(0);
-            }
-        }
-
-        #pragma unroll
-        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
-            int element_index = local_idx + it * WARP_SIZE;
-            if (element_index < batch_element_count) {
-                grad_reg[i][it] = (acc_t)grad[i * element_count * stride + it * WARP_SIZE] * output_reg[i][it];
-            } else {
-                grad_reg[i][it] = acc_t(0);
+                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count * stride + it * WARP_SIZE);
+                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count * stride + it * WARP_SIZE);
+
+                #pragma unroll
+                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                    if (element_index + element < batch_element_count) {
+                        output_reg[i][it + element] = (acc_t)temp_output[element];
+                    }
+                }
+                #pragma unroll
+                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                    if (element_index + element < batch_element_count) {
+                        grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
+                    }
+                }
             }
         }
     }
 
     acc_t sum[WARP_BATCH];
...
@@ -247,11 +314,16 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
         if (i >= local_batches)
             break;
 
         #pragma unroll
-        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
-            int element_index = local_idx + it * WARP_SIZE;
+        for (int it = 0;  it < WARP_ITERATIONS;  it += ELEMENTS_PER_LDG_STG) {
+            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
             if (element_index < element_count) {
                 // compute gradients
-                gradInput[i * element_count * stride + it * WARP_SIZE] = (output_t)(scale * (grad_reg[i][it] - output_reg[i][it] * sum[i]));
+                output_t out[ELEMENTS_PER_LDG_STG];
+                #pragma unroll
+                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                    out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
+                }
+                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count * stride + it * WARP_SIZE, out);
             }
         }
     }
 }
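The backward hunk implements the usual softmax Jacobian-vector product: with y the saved softmax output and dy the incoming gradient, the kernel keeps grad_reg = dy .* y, reduces it into sum, and writes dx = scale * (dy .* y - y * sum). A scalar host-side reference of the same computation, for illustration only:

#include <cstddef>
#include <vector>

// Reference softmax backward for one row: dx_j = scale * y_j * (dy_j - sum_k dy_k * y_k).
// Written to mirror the kernel's intermediates: dy[j] * y[j] plays the role of
// grad_reg, and `sum` corresponds to the warp-reduced sum[i].
std::vector<float> softmax_backward_row(const std::vector<float>& y,   // saved softmax output
                                        const std::vector<float>& dy,  // upstream gradient
                                        float scale)
{
    float sum = 0.0f;
    for (std::size_t j = 0; j < y.size(); ++j)
        sum += dy[j] * y[j];                          // sum of grad_reg

    std::vector<float> dx(y.size());
    for (std::size_t j = 0; j < y.size(); ++j)
        dx[j] = scale * (dy[j] * y[j] - y[j] * sum);  // matches gradInput in the kernel
    return dx;
}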
...
megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu
View file @ ee7b19e7
...
@@ -19,10 +19,10 @@
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
 #include <cuda_profiler_api.h>
-#include "THC/THC.h"
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
 #include "scaled_upper_triang_masked_softmax.h"
+#include "type_shim.h"
 
 namespace multihead_attn {
 namespace fused_softmax {
...
@@ -46,15 +46,20 @@ torch::Tensor fwd_cuda(
   void* input_ptr = static_cast<void*>(input.data_ptr());
   void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
 
-  dispatch_scaled_upper_triang_masked_softmax_forward<half, half, float>(
-      reinterpret_cast<half*>(softmax_results_ptr),
-      reinterpret_cast<const half*>(input_ptr),
-      scale_factor,
-      seq_len,
-      seq_len,
-      attn_batches);
+  DISPATCH_HALF_AND_BFLOAT(
+      input.scalar_type(),
+      "dispatch_scaled_upper_triang_masked_softmax_forward",
+      dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t, float>(
+          reinterpret_cast<scalar_t*>(softmax_results_ptr),
+          reinterpret_cast<const scalar_t*>(input_ptr),
+          scale_factor,
+          seq_len,
+          seq_len,
+          attn_batches);
+  );
 
   return softmax_results;
 }
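The DISPATCH_HALF_AND_BFLOAT macro used here is defined in type_shim.h further down in this commit; it switches on the runtime dtype and binds scalar_t before running the body, so the same call site covers both fp16 and bf16. Roughly, the Half branch of the call in fwd_cuda expands to the following; this is a hand-expansion for illustration, not literal preprocessor output:

// Approximate expansion of the DISPATCH_HALF_AND_BFLOAT call in fwd_cuda,
// shown for the at::ScalarType::Half case (the BFloat16 case is analogous).
switch (input.scalar_type()) {
  case at::ScalarType::Half: {
    using scalar_t = at::Half;
    dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t, float>(
        reinterpret_cast<scalar_t*>(softmax_results_ptr),
        reinterpret_cast<const scalar_t*>(input_ptr),
        scale_factor, seq_len, seq_len, attn_batches);
    break;
  }
  // case at::ScalarType::BFloat16: using scalar_t = at::BFloat16; ... (same body)
  default:
    AT_ERROR("dispatch_scaled_upper_triang_masked_softmax_forward",
             " not implemented for '", toString(input.scalar_type()), "'");
}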
 torch::Tensor bwd_cuda(
     torch::Tensor const& output_grads_,
...
@@ -72,14 +77,18 @@ torch::Tensor bwd_cuda(
   void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
 
   //Softmax Grad
-  dispatch_scaled_upper_triang_masked_softmax_backward<half, half, float>(
-      reinterpret_cast<half*>(output_grads_ptr),
-      reinterpret_cast<half*>(output_grads_ptr),
-      reinterpret_cast<half const*>(softmax_results.data_ptr()),
-      scale_factor,
-      seq_len,
-      seq_len,
-      attn_batches);
+  DISPATCH_HALF_AND_BFLOAT(
+      output_grads_.scalar_type(),
+      "dispatch_scaled_upper_triang_masked_softmax_backward",
+      dispatch_scaled_upper_triang_masked_softmax_backward<scalar_t, scalar_t, float>(
+          reinterpret_cast<scalar_t*>(output_grads_ptr),
+          reinterpret_cast<scalar_t*>(output_grads_ptr),
+          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
+          scale_factor,
+          seq_len,
+          seq_len,
+          attn_batches);
+  );
 
   //backward pass is completely in-place
   return output_grads;
...
megatron/fused_kernels/type_shim.h
View file @ ee7b19e7
...
@@ -14,214 +14,78 @@
  * limitations under the License.
  */
 
+/*This code is copied fron NVIDIA apex:
+ *     https://github.com/NVIDIA/apex
+ *    with minor changes. */
 
 #include <ATen/ATen.h>
 #include "compat.h"
 
-// Forward/backward compatiblity hack around
-// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
-// pending more future-proof guidance from upstream.
-// struct TypeShim
-// {
-//   const at::Type& payload;
-//   TypeShim(const at::Type& type) : payload(type) {}
-//   // Enable trivial conversion to a const at::Type& for pre-3aeb78
-//   operator const at::Type&(){ return payload; };
-//   // Enable dispatch switch statements to take *this directly for post-3aeb78
-//   //operator at::ScalarType(){ return payload.; };
-// };
 
-#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
+#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
   switch(TYPE) \
   { \
-    case at::ScalarType::Float: \
+    case at::ScalarType::Half: \
     { \
-      using scalar_t_##LEVEL = float; \
+      using scalar_t = at::Half; \
       __VA_ARGS__; \
       break; \
     } \
-    case at::ScalarType::Half: \
+    case at::ScalarType::BFloat16: \
    { \
-      using scalar_t_##LEVEL = at::Half; \
+      using scalar_t = at::BFloat16; \
       __VA_ARGS__; \
       break; \
     } \
     default: \
       AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
   }
 
+#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
+  switch(TYPEIN) \
+  { \
+    case at::ScalarType::Float: \
+    { \
+      using scalar_t_in = float; \
+      switch(TYPEOUT) \
+      { \
+        case at::ScalarType::Float: \
+        { \
+          using scalar_t_out = float; \
+          __VA_ARGS__; \
+          break; \
+        } \
+        case at::ScalarType::Half: \
+        { \
+          using scalar_t_out = at::Half; \
+          __VA_ARGS__; \
+          break; \
+        } \
+        case at::ScalarType::BFloat16: \
+        { \
+          using scalar_t_out = at::BFloat16; \
+          __VA_ARGS__; \
+          break; \
+        } \
+        default: \
+          AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
+      } \
+      break; \
+    } \
+    case at::ScalarType::Half: \
+    { \
+      using scalar_t_in = at::Half; \
+      using scalar_t_out = at::Half; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    case at::ScalarType::BFloat16: \
+    { \
+      using scalar_t_in = at::BFloat16; \
+      using scalar_t_out = at::BFloat16; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    default: \
+      AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
+  }
 
-#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
-  switch(TYPE) \
-  { \
-    case at::ScalarType::Float: \
-    { \
-      using scalar_t_##LEVEL = float; \
-      __VA_ARGS__; \
-      break; \
-    } \
-    case at::ScalarType::Half: \
-    { \
-      using scalar_t_##LEVEL = at::Half; \
-      __VA_ARGS__; \
-      break; \
-    } \
-    case at::ScalarType::Byte: \
-    { \
-      using scalar_t_##LEVEL = uint8_t; \
-      __VA_ARGS__; \
-      break; \
-    } \
-    default: \
-      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
-  }
 
-#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
-  switch(TYPE) \
-  { \
-    case at::ScalarType::Double: \
-    { \
-      using scalar_t_##LEVEL = double; \
-      __VA_ARGS__; \
-      break; \
-    } \
-    case at::ScalarType::Float: \
-    { \
-      using scalar_t_##LEVEL = float; \
-      __VA_ARGS__; \
-      break; \
-    } \
-    case at::ScalarType::Half: \
-    { \
-      using scalar_t_##LEVEL = at::Half; \
-      __VA_ARGS__; \
-      break; \
-    } \
-    default: \
-      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
-  }
 
-#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
-  switch(TYPE) \
-  { \
-    case at::ScalarType::Double: \
-    { \
-      using scalar_t_##LEVEL = double; \
-      __VA_ARGS__; \
-      break; \
-    } \
-    case at::ScalarType::Float: \
-    { \
-      using scalar_t_##LEVEL = float; \
-      __VA_ARGS__; \
-      break; \
-    } \
-    default: \
-      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
-  }
 
-template<typename T>
-__device__ __forceinline__ T reduce_block_into_lanes
-  (T *x,
-   T val,
-   int lanes=1,
-   bool share_result=false) // lanes is intended to be <= 32.
-{
-  int tid = threadIdx.x + threadIdx.y*blockDim.x;
-  int blockSize = blockDim.x*blockDim.y;
-  // blockSize is intended to be a multiple of 32.
-
-  if(blockSize >= 64)
-  {
-    x[tid] = val;
-    __syncthreads();
-  }
-
-  #pragma unroll
-  for(int i = (blockSize >> 1); i >= 64; i >>= 1)
-  {
-    if(tid < i)
-      x[tid] = x[tid] + x[tid+i];
-    __syncthreads();
-  }
-
-  T final;
-
-  if(tid < 32)
-  {
-    if(blockSize >= 64)
-      final = x[tid] + x[tid+32];
-    else
-      final = val;
-    // __SYNCWARP();
-
-    #pragma unroll
-    for(int i = 16; i >= lanes; i >>= 1)
-      final = final + __shfl_down_sync(0xffffffff, final, i);
-  }
-
-  if(share_result)
-  {
-    if(tid < lanes)
-      x[tid] = final; // EpilogueOp
-    // Make sure the smem result is visible to all warps.
-    __syncthreads();
-  }
-
-  return final;
-}
-
-template<typename T>
-__device__ __forceinline__ T reduce_block_into_lanes_max_op
-  (T *x,
-   T val,
-   int lanes=1,
-   bool share_result=false) // lanes is intended to be <= 32.
-{
-  int tid = threadIdx.x + threadIdx.y*blockDim.x;
-  int blockSize = blockDim.x*blockDim.y;
-  // blockSize is intended to be a multiple of 32.
-
-  if(blockSize >= 64)
-  {
-    x[tid] = val;
-    __syncthreads();
-  }
-
-  #pragma unroll
-  for(int i = (blockSize >> 1); i >= 64; i >>= 1)
-  {
-    if(tid < i)
-      x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid+i]));
-    __syncthreads();
-  }
-
-  T final;
-
-  if(tid < 32)
-  {
-    if(blockSize >= 64)
-      final = fmaxf(fabsf(x[tid]), fabsf(x[tid+32]));
-    else
-      final = val;
-    // __SYNCWARP();
-
-    #pragma unroll
-    for(int i = 16; i >= lanes; i >>= 1)
-      final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
-  }
-
-  if(share_result)
-  {
-    if(tid < lanes)
-      x[tid] = final; // EpilogueOp
-    // Make sure the smem result is visible to all warps.
-    __syncthreads();
-  }
-
-  return final;
-}
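The added DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES macro binds separate input and output element types (scalar_t_in / scalar_t_out), for kernels whose inputs and outputs may use different precisions. It is not exercised by the files shown on this page, so the following usage sketch is hypothetical; my_kernel_launcher is a placeholder name, not code from this commit:

#include <ATen/ATen.h>
#include "type_shim.h"  // for DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES

// Placeholder templated launcher, declared only to make the sketch self-contained.
template <typename scalar_t_in, typename scalar_t_out>
void my_kernel_launcher(const scalar_t_in* in, scalar_t_out* out, int64_t n);

// Dispatch on the runtime dtypes of an input/output tensor pair; the macro
// binds scalar_t_in and scalar_t_out and then runs the body for that case.
void run(at::Tensor input, at::Tensor output)
{
    DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
        input.scalar_type(), output.scalar_type(), "my_kernel_launcher",
        my_kernel_launcher<scalar_t_in, scalar_t_out>(
            input.data_ptr<scalar_t_in>(),
            output.data_ptr<scalar_t_out>(),
            input.numel());
    );
}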