OpenDAS / Megatron-LM · Commits

Commit aed2f75e, authored Apr 11, 2021 by Jared Casper

    Merge branch 'main' into github-main

Parents: 8aa4619f, f32a638d
Changes: 96

Showing 20 changed files with 1384 additions and 861 deletions (+1384 -861).
Changed files shown in this view:

  megatron/data/ict_dataset.py                                        +20   -4
  megatron/data/orqa_wiki_dataset.py                                 +205   -0
  megatron/data/realm_index.py                                        +88  -80
  megatron/data/vit_dataset.py                                        +58   -0
  megatron/fused_kernels/__init__.py                                  +70  -87
  megatron/fused_kernels/layer_norm_cuda.cpp                          +32  -91
  megatron/fused_kernels/layer_norm_cuda_kernel.cu                    +33  -33
  megatron/fused_kernels/scaled_masked_softmax.cpp                     +9   -6
  megatron/fused_kernels/scaled_masked_softmax.h                     +123  -83
  megatron/fused_kernels/scaled_masked_softmax_cuda.cu                +37  -27
  megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp       +10   -7
  megatron/fused_kernels/scaled_upper_triang_masked_softmax.h        +109  -37
  megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu   +25  -16
  megatron/fused_kernels/type_shim.h                                  +69 -205
  megatron/global_vars.py                                              +5   -3
  megatron/indexer.py                                                 +79  -43
  megatron/initialize.py                                              +73  -16
  megatron/model/__init__.py                                           +5  -29
  megatron/model/bert_model.py                                        +39  -94
  megatron/model/biencoder_model.py                                  +295   -0
megatron/data/ict_dataset.py (+20 -4)

@@ -9,6 +9,16 @@ from megatron import get_args
 from megatron.data.dataset_utils import get_indexed_dataset_
 from megatron.data.realm_dataset_utils import get_block_samples_mapping

+def make_attention_mask(source_block, target_block):
+    """
+    Returns a 2-dimensional (2-D) attention mask
+    :param source_block: 1-D array
+    :param target_block: 1-D array
+    """
+    mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
+    mask = mask.astype(np.int64)
+    # (source_length, target_length)
+    return mask

 def get_ict_dataset(use_titles=True, query_in_block_prob=1):
     """Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block())

@@ -39,7 +49,7 @@ class ICTDataset(Dataset):
     """Dataset containing sentences and their blocks for an inverse cloze task."""
     def __init__(self, name, block_dataset, title_dataset, data_prefix,
                  num_epochs, max_num_samples, max_seq_length, query_in_block_prob,
-                 seed, use_titles=True, use_one_sent_docs=False):
+                 seed, use_titles=True, use_one_sent_docs=False, binary_head=False):
         self.name = name
         self.seed = seed
         self.max_seq_length = max_seq_length

@@ -93,14 +103,20 @@ class ICTDataset(Dataset):
         block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset]

         query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
-        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
+        context_tokens, context_pad_mask = self.concat_and_pad_tokens(block, title)
+
+        query_mask = make_attention_mask(query_tokens, query_tokens)
+        context_mask = make_attention_mask(context_tokens, context_tokens)

         block_data = sample_data.as_array()

         sample = {
             'query_tokens': query_tokens,
+            'query_mask': query_mask,
             'query_pad_mask': query_pad_mask,
-            'block_tokens': block_tokens,
-            'block_pad_mask': block_pad_mask,
+            'context_tokens': context_tokens,
+            'context_mask': context_mask,
+            'context_pad_mask': context_pad_mask,
             'block_data': block_data,
         }
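The make_attention_mask helper added above builds the 2-D mask as an outer product of the non-pad indicators of the two token arrays. A minimal sketch of its behaviour on toy data (the token ids and the pad id of 0 are illustrative assumptions, not taken from the repo):

# Minimal sketch: the same function as in the diff, applied to a toy
# padded id array where 0 is assumed to be the pad id.
import numpy as np

def make_attention_mask(source_block, target_block):
    # Outer product of "is a real token" indicators gives a
    # (source_length, target_length) 0/1 mask.
    mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
    return mask.astype(np.int64)

query = np.array([101, 7592, 2088, 102, 0, 0])  # 4 real tokens, 2 pads
mask = make_attention_mask(query, query)
print(mask.shape)           # (6, 6)
print(bool(mask[:4, :4].all()))  # True: real-token pairs may attend
print(bool(mask[:, 4:].any()))   # False: pad columns are masked out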
megatron/data/orqa_wiki_dataset.py (new file, mode 100644, +205 -0)

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Wikipedia dataset from DPR code for ORQA."""

from abc import ABC
import csv
import numpy as np
import random
import torch
from torch.utils.data import Dataset

from megatron import print_rank_0, get_args, get_tokenizer, mpu
from megatron.data.biencoder_dataset_utils import make_attention_mask


def get_open_retrieval_wiki_dataset():
    args = get_args()
    tokenizer = get_tokenizer()

    dataset = OpenRetrievalEvidenceDataset('2018 Wikipedia from DPR codebase',
                                           'evidence',
                                           args.evidence_data_path,
                                           tokenizer,
                                           args.retriever_seq_length)
    return dataset


def get_open_retrieval_batch(data_iterator):
    # Items and their type.
    keys = ['row_id', 'context', 'context_mask', 'context_types',
            'context_pad_mask']
    datatype = torch.int64

    # Broadcast data.
    data = None if data_iterator is None else next(data_iterator)
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    row_id = data_b['row_id'].long()
    context = data_b['context'].long()

    # TODO: make the context mask a binary one
    context_mask = (data_b['context_mask'] < 0.5)

    context_types = data_b['context_types'].long()
    context_pad_mask = data_b['context_pad_mask'].long()

    return row_id, context, context_mask, context_types, context_pad_mask


def build_tokens_types_paddings_from_text(row, tokenizer, max_seq_length):
    """Build token types and paddings, trim if needed, and pad if needed."""

    title_ids = tokenizer.tokenize(row['title'])
    context_ids = tokenizer.tokenize(row['text'])

    # Appending the title of the context at front
    extended_context_ids = title_ids + [tokenizer.sep_id] + context_ids

    context_ids, context_types, context_pad_mask = \
        build_tokens_types_paddings_from_ids(extended_context_ids,
            max_seq_length, tokenizer.cls, tokenizer.sep, tokenizer.pad)

    return context_ids, context_types, context_pad_mask


# noinspection DuplicatedCode
def build_tokens_types_paddings_from_ids(text_ids, max_seq_length,
                                         cls_id, sep_id, pad_id):
    """Build token types and paddings, trim if needed, and pad if needed."""
    enc_ids = []
    tokentypes_enc = []

    # [CLS].
    enc_ids.append(cls_id)
    tokentypes_enc.append(0)

    # A.
    len_src = len(text_ids)
    enc_ids.extend(text_ids)
    tokentypes_enc.extend([0] * len_src)

    # Cap the size.
    if len(enc_ids) > max_seq_length - 1:
        enc_ids = enc_ids[0: max_seq_length - 1]
        tokentypes_enc = tokentypes_enc[0: max_seq_length - 1]

    # [SEP].
    enc_ids.append(sep_id)
    tokentypes_enc.append(0)
    num_tokens_enc = len(enc_ids)

    # Padding.
    padding_length = max_seq_length - len(enc_ids)
    if padding_length > 0:
        enc_ids.extend([pad_id] * padding_length)
        tokentypes_enc.extend([pad_id] * padding_length)

    pad_mask = ([1] * num_tokens_enc) + ([0] * padding_length)
    pad_mask = np.array(pad_mask, dtype=np.int64)

    return enc_ids, tokentypes_enc, pad_mask


def build_sample(row_id, context_ids, context_types, context_pad_mask):
    """Convert to numpy and return a sample consumed by the batch producer."""

    context_ids = np.array(context_ids, dtype=np.int64)
    context_types = np.array(context_types, dtype=np.int64)
    context_mask = make_attention_mask(context_ids, context_ids)

    sample = ({
        'row_id': row_id,
        'context': context_ids,
        'context_mask': context_mask,
        'context_types': context_types,
        'context_pad_mask': context_pad_mask
    })
    return sample


class OpenRetrievalEvidenceDataset(ABC, Dataset):
    """Open Retrieval Evidence dataset class."""

    def __init__(self, task_name, dataset_name, datapath, tokenizer,
                 max_seq_length):
        # Store inputs.
        self.task_name = task_name
        self.dataset_name = dataset_name
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        print_rank_0(' > building {} dataset for {}:'.format(self.task_name,
                                                             self.dataset_name))
        # Process the files.
        print_rank_0(datapath)
        self.samples, self.id2text = self.process_samples_from_single_path(
                                        datapath)

        args = get_args()
        if args.sample_rate < 1:  # subsample
            k = int(len(self.samples) * args.sample_rate)
            self.samples = random.sample(self.samples, k)

        print_rank_0('  >> total number of samples: {}'.format(
                        len(self.samples)))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        row = self.samples[idx]

        context_ids, context_types, context_pad_mask = \
            build_tokens_types_paddings_from_text(row, self.tokenizer,
                                                  self.max_seq_length)

        sample = build_sample(row['doc_id'], context_ids, context_types,
                              context_pad_mask)
        return sample

    @staticmethod
    def process_samples_from_single_path(filename):
        print_rank_0(' > Processing {} ...'.format(filename))
        total = 0

        rows = []
        id2text = {}

        with open(filename) as tsvfile:
            reader = csv.reader(tsvfile, delimiter='\t')
            next(reader, None)  # skip the headers
            for row in reader:
                # file format: doc_id, doc_text, title
                doc_id = int(row[0])
                text = row[1]
                title = row[2]

                rows.append({'doc_id': doc_id,
                             'text': text,
                             'title': title})

                assert doc_id not in id2text
                id2text[doc_id] = (text, title)

                total += 1
                if total % 100000 == 0:
                    print_rank_0('  > processed {} rows so far ...'.format(
                                    total))

        print_rank_0(' >> processed {} samples.'.format(len(rows)))
        return rows, id2text
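build_tokens_types_paddings_from_ids above wraps the ids as [CLS] ids [SEP], truncates to max_seq_length - 1 before appending [SEP], and pads both the ids and the token types with the pad id. A minimal sketch of the expected output for toy inputs, assuming the module above is importable from the path shown; the special-token ids 101/102/0 are made-up values for illustration:

# Toy trace of the truncation/padding logic for max_seq_length=8.
from megatron.data.orqa_wiki_dataset import build_tokens_types_paddings_from_ids

enc_ids, tokentypes, pad_mask = build_tokens_types_paddings_from_ids(
    text_ids=[11, 12, 13, 14], max_seq_length=8,
    cls_id=101, sep_id=102, pad_id=0)

print(enc_ids)            # [101, 11, 12, 13, 14, 102, 0, 0]
print(tokentypes)         # [0, 0, 0, 0, 0, 0, 0, 0]
print(pad_mask.tolist())  # [1, 1, 1, 1, 1, 1, 0, 0]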
megatron/data/realm_index.py (+88 -80)

@@ -14,34 +14,36 @@ def detach(tensor):
     return tensor.detach().cpu().numpy()


-class BlockData(object):
-    """Serializable data structure for holding data for blocks -- embeddings and necessary metadata for REALM"""
-    def __init__(self, block_data_path=None, load_from_path=True, rank=None):
+class OpenRetreivalDataStore(object):
+    """
+    Serializable data structure for holding data for blocks --
+    embeddings and necessary metadata for Retriever
+    """
+    def __init__(self, embedding_path=None, load_from_path=True, rank=None):
         self.embed_data = dict()
         self.meta_data = dict()
-        if block_data_path is None:
+        if embedding_path is None:
             args = get_args()
-            block_data_path = args.block_data_path
+            embedding_path = args.embedding_path
             rank = args.rank
-        self.block_data_path = block_data_path
+        self.embedding_path = embedding_path
         self.rank = rank

         if load_from_path:
             self.load_from_file()

-        block_data_name = os.path.splitext(self.block_data_path)[0]
+        block_data_name = os.path.splitext(self.embedding_path)[0]
         self.temp_dir_name = block_data_name + '_tmp'

     def state(self):
         return {
             'embed_data': self.embed_data,
             'meta_data': self.meta_data,
         }

     def clear(self):
-        """Clear the embedding data structures to save memory.
-        The metadata ends up getting used, and is also much smaller in dimensionality
-        so it isn't really worth clearing.
         """
+        Clear the embedding data structures to save memory.
+        The metadata ends up getting used, and is also much smaller in
+        dimensionality so it isn't really worth clearing.
+        """
         self.embed_data = dict()

@@ -50,38 +52,39 @@ class BlockData(object):
         if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
             print("\n> Unpickling BlockData", flush=True)
-        state_dict = pickle.load(open(self.block_data_path, 'rb'))
+        state_dict = pickle.load(open(self.embedding_path, 'rb'))
         if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
             print(">> Finished unpickling BlockData\n", flush=True)

         self.embed_data = state_dict['embed_data']
         self.meta_data = state_dict['meta_data']

-    def add_block_data(self, block_indices, block_embeds, block_metas,
-                       allow_overwrite=False):
-        """Add data for set of blocks
-        :param block_indices: 1D array of unique int ids for the blocks
+    def add_block_data(self, row_id, block_embeds, allow_overwrite=False):
+        """
+        Add data for set of blocks
+        :param row_id: 1D array of unique int ids for the blocks
         :param block_embeds: 2D array of embeddings of the blocks
-        :param block_metas: 2D array of metadata for the blocks.
-            In the case of REALM this will be [start_idx, end_idx, doc_idx]
+            In the case of retriever this will be [start_idx, end_idx, doc_idx]
         """
-        for idx, embed, meta in zip(block_indices, block_embeds, block_metas):
+        for idx, embed in zip(row_id, block_embeds):
             if not allow_overwrite and idx in self.embed_data:
                 raise ValueError("Unexpectedly tried to overwrite block data")

             self.embed_data[idx] = np.float16(embed)
-            self.meta_data[idx] = meta

     def save_shard(self):
-        """Save the block data that was created this in this process"""
+        """
+        Save the block data that was created this in this process
+        """
         if not os.path.isdir(self.temp_dir_name):
             os.makedirs(self.temp_dir_name, exist_ok=True)

         # save the data for each shard
-        with open('{}/{}.pkl'.format(self.temp_dir_name, self.rank), 'wb') as data_file:
-            pickle.dump(self.state(), data_file)
+        with open('{}/{}.pkl'.format(self.temp_dir_name, self.rank), 'wb') \
+            as writer:
+            pickle.dump(self.state(), writer)

     def merge_shards_and_save(self):
-        """Combine all the shards made using self.save_shard()"""
+        # Combine all the shards made using save_shard
         shard_names = os.listdir(self.temp_dir_name)
         seen_own_shard = False

@@ -96,15 +99,15 @@ class BlockData(object):
                     old_size = len(self.embed_data)
                     shard_size = len(data['embed_data'])

-                    # add the shard's data and check to make sure there is no overlap
+                    # add the shard's data and check to make sure there
+                    # is no overlap
                     self.embed_data.update(data['embed_data'])
                     self.meta_data.update(data['meta_data'])
                     assert len(self.embed_data) == old_size + shard_size

         assert seen_own_shard

         # save the consolidated shards and remove temporary directory
-        with open(self.block_data_path, 'wb') as final_file:
+        with open(self.embedding_path, 'wb') as final_file:
             pickle.dump(self.state(), final_file)
         shutil.rmtree(self.temp_dir_name, ignore_errors=True)

@@ -113,18 +116,22 @@ class BlockData(object):
 class FaissMIPSIndex(object):
-    """Wrapper object for a BlockData which similarity search via FAISS under the hood"""
-    def __init__(self, embed_size, block_data=None, use_gpu=False):
+    """
+    Wrapper object for a BlockData which similarity search via FAISS under the hood
+    """
+    def __init__(self, embed_size, embed_data=None, use_gpu=False):
         self.embed_size = embed_size
-        self.block_data = block_data
+        self.embed_data = embed_data
         self.use_gpu = use_gpu
-        self.id_map = dict()

-        self.block_mips_index = None
-        self._set_block_index()
+        self.mips_index = None
+        self._set_mips_index()

-    def _set_block_index(self):
-        """Create a Faiss Flat index with inner product as the metric to search against"""
+    def _set_mips_index(self):
+        """
+        Create a Faiss Flat index with inner product as the metric
+        to search against
+        """
         try:
             import faiss
         except ImportError:

@@ -132,85 +139,86 @@ class FaissMIPSIndex(object):
         if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
             print("\n> Building index", flush=True)
-        self.block_mips_index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)
+
+        cpu_index = faiss.IndexFlatIP(self.embed_size)

         if self.use_gpu:
             # create resources and config for GpuIndex
-            res = faiss.StandardGpuResources()
-            config = faiss.GpuIndexFlatConfig()
-            config.device = torch.cuda.current_device()
+            config = faiss.GpuMultipleClonerOptions()
+            config.shard = True
             config.useFloat16 = True
-            self.block_mips_index = faiss.GpuIndexFlat(res, self.block_mips_index, config)
+            gpu_index = faiss.index_cpu_to_all_gpus(cpu_index, co=config)
+            self.mips_index = faiss.IndexIDMap(gpu_index)

             if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
-                print(">> Initialized index on GPU {}".format(self.block_mips_index.getDevice()), flush=True)
+                print(">> Initialized index on GPU", flush=True)
         else:
             # CPU index supports IDs so wrap with IDMap
-            self.block_mips_index = faiss.IndexIDMap(self.block_mips_index)
+            self.mips_index = faiss.IndexIDMap(cpu_index)

             if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
                 print(">> Initialized index on CPU", flush=True)

-        # if we were constructed with a BlockData, then automatically load it when the FAISS structure is built
-        if self.block_data is not None:
-            self.add_block_embed_data(self.block_data)
+        # if we were constructed with a BlockData, then automatically load it
+        # when the FAISS structure is built
+        if self.embed_data is not None:
+            self.add_embed_data(self.embed_data)

     def reset_index(self):
-        """Delete existing index and create anew"""
-        del self.block_mips_index
+        """Delete existing index and create a new"""
+        del self.mips_index

         # reset the block data so that _set_block_index will reload it as well
-        if self.block_data is not None:
-            block_data_path = self.block_data.block_data_path
-            del self.block_data
-            self.block_data = BlockData(block_data_path)
+        if self.embed_data is not None:
+            embed_data_path = self.embed_data.embedding_path
+            del self.embed_data
+            self.embed_data = OpenRetreivalDataStore(embed_data_path)

-        self._set_block_index()
+        self._set_mips_index()

-    def add_block_embed_data(self, all_block_data):
+    def update_index(self):
+        """Delete existing index and create a new"""
+        del self.mips_index
+
+        # reset the block data so that _set_mips_index will reload it as well
+        if self.embed_data is not None:
+            self.embed_data.load_from_file()
+        self._set_mips_index()
+
+    def add_embed_data(self, all_embed_data):
         """Add the embedding of each block to the underlying FAISS index"""

         # this assumes the embed_data is a dict : {int: np.array<float>}
-        block_indices, block_embeds = zip(*all_block_data.embed_data.items())
+        block_indices, block_embeds = zip(*all_embed_data.embed_data.items())

-        # the embeddings have to be entered in as float32 even though the math internally is done with float16.
-        block_embeds_arr = np.float32(np.array(block_embeds))
-        block_indices_arr = np.array(block_indices)
-
-        # faiss GpuIndex doesn't work with IDMap wrapper so store ids to map back with
-        if self.use_gpu:
-            for i, idx in enumerate(block_indices):
-                self.id_map[i] = idx
+        # the embeddings have to be entered in as float32 even though the math
+        # internally is done with float16.
+        embeds_arr = np.float32(np.array(block_embeds))
+        indices_arr = np.array(block_indices)

         # we no longer need the embedding data since it's in the index now
-        all_block_data.clear()
+        all_embed_data.clear()

-        if self.use_gpu:
-            self.block_mips_index.add(block_embeds_arr)
-        else:
-            self.block_mips_index.add_with_ids(block_embeds_arr, block_indices_arr)
+        self.mips_index.add_with_ids(embeds_arr, indices_arr)

         if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
             print(">>> Finished adding block data to index", flush=True)

     def search_mips_index(self, query_embeds, top_k, reconstruct=True):
-        """Get the top-k blocks by the index distance metric.
-        :param reconstruct: if True: return a [num_queries x k x embed_dim] array of blocks
-                            if False: return [num_queries x k] array of distances, and another for indices
+        """
+        Get the top-k blocks by the index distance metric.
+        :param reconstruct: if True: return a [num_queries x k x embed_dim]
+                                array of blocks
+                            if False: return [num_queries x k] array of
+                                distances, and another for indices
         """
         query_embeds = np.float32(detach(query_embeds))

         if reconstruct:
             # get the vectors themselves
-            top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds, top_k)
+            top_k_block_embeds = self.mips_index.search_and_reconstruct(\
+                query_embeds, top_k)
             return top_k_block_embeds
         else:
             # get distances and indices of closest vectors
-            distances, block_indices = self.block_mips_index.search(query_embeds, top_k)
-            if self.use_gpu:
-                fresh_indices = np.zeros(block_indices.shape)
-                for i, j in itertools.product(block_indices.shape):
-                    fresh_indices[i, j] = self.id_map[block_indices[i, j]]
-                block_indices = fresh_indices
+            distances, block_indices = self.mips_index.search(query_embeds,
+                top_k)
             return distances, block_indices
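The rewritten index code drops the GPU-specific GpuIndexFlat plus manual id_map bookkeeping and instead clones a flat inner-product index across GPUs and wraps it in IndexIDMap, so searches return the caller's row ids on both CPU and GPU. A minimal CPU-only sketch of that pattern with synthetic embeddings (assumes the faiss and numpy packages are installed; none of the data comes from the repo):

# Sketch of the IndexFlatIP + IndexIDMap pattern used above.
import numpy as np
import faiss

embed_size = 128
row_ids = np.array([10, 42, 77, 99], dtype=np.int64)
embeds = np.float32(np.random.randn(len(row_ids), embed_size))

cpu_index = faiss.IndexFlatIP(embed_size)   # inner-product metric
mips_index = faiss.IndexIDMap(cpu_index)    # keeps explicit ids
mips_index.add_with_ids(embeds, row_ids)

query = np.float32(np.random.randn(1, embed_size))
distances, ids = mips_index.search(query, 2)
print(ids)  # top-2 row ids from row_ids, not positional indices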
megatron/data/vit_dataset.py (new file, mode 100644, +58 -0)

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import torch
from torchvision import datasets, transforms
from megatron.data.autoaugment import ImageNetPolicy


def build_train_valid_datasets(data_path, crop_size=224, color_jitter=True):

    # training dataset
    train_data_path = os.path.join(data_path[0], "train")
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    process = [
        transforms.RandomResizedCrop(crop_size),
        transforms.RandomHorizontalFlip(),
    ]
    if color_jitter:
        process += [
            transforms.ColorJitter(
                brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1
            )
        ]
    fp16_t = transforms.ConvertImageDtype(torch.half)
    process += [ImageNetPolicy(), transforms.ToTensor(), normalize, fp16_t]
    transform_train = transforms.Compose(process)
    train_data = datasets.ImageFolder(
        root=train_data_path, transform=transform_train
    )

    # validation dataset
    val_data_path = os.path.join(data_path[0], "val")
    transform_val = transforms.Compose(
        [
            transforms.Resize(crop_size),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            normalize,
            fp16_t
        ]
    )
    val_data = datasets.ImageFolder(
        root=val_data_path, transform=transform_val
    )

    return train_data, val_data
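A minimal usage sketch of the new dataset builder, assuming a hypothetical ImageNet-style directory with class subfolders under train/ and val/; the path and loader settings are illustrative, not from the repo:

# Build the ViT datasets and pull one half-precision batch.
import torch
from megatron.data.vit_dataset import build_train_valid_datasets

train_data, val_data = build_train_valid_datasets(
    data_path=["/data/imagenet"], crop_size=224, color_jitter=True)

loader = torch.utils.data.DataLoader(
    train_data, batch_size=32, shuffle=True, num_workers=4)
images, labels = next(iter(loader))
print(images.shape, images.dtype)  # e.g. torch.Size([32, 3, 224, 224]) torch.float16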
megatron/fused_kernels/__init__.py (+70 -87)

@@ -13,114 +13,97 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import os
 import pathlib
 import subprocess

 from torch.utils import cpp_extension

-# Setting this param to a list has a problem of generating
-# different compilation commands (with diferent order of architectures)
-# and leading to recompilation of fused kernels.
-# set it to empty string to avoid recompilation
-# and assign arch flags explicity in extra_cuda_cflags below
+# Setting this param to a list has a problem of generating different
+# compilation commands (with diferent order of architectures) and
+# leading to recompilation of fused kernels. Set it to empty string
+# to avoid recompilation and assign arch flags explicity in
+# extra_cuda_cflags below
 os.environ["TORCH_CUDA_ARCH_LIST"] = ""


-def get_cuda_bare_metal_version(cuda_dir):
-    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
-                                         universal_newlines=True)
-    output = raw_output.split()
-    release_idx = output.index("release") + 1
-    release = output[release_idx].split(".")
-    bare_metal_major = release[0]
-    bare_metal_minor = release[1][0]
-    return raw_output, bare_metal_major, bare_metal_minor
-
-
-def create_build_dir(buildpath):
-    try:
-        os.mkdir(buildpath)
-    except OSError:
-        if not os.path.isdir(buildpath):
-            print(f"Creation of the build directory {buildpath} failed")
-
-
-def load_scaled_upper_triang_masked_softmax_fusion_kernel():
-    # Check, if CUDA11 is installed for compute capability 8.0
+def load(args):
+
+    # Check if cuda 11 is installed for compute capability 8.0
     cc_flag = []
-    _, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
+    _, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
     if int(bare_metal_major) >= 11:
         cc_flag.append('-gencode')
         cc_flag.append('arch=compute_80,code=sm_80')

     # Build path
     srcpath = pathlib.Path(__file__).parent.absolute()
     buildpath = srcpath / 'build'
-    create_build_dir(buildpath)
-
-    scaled_upper_triang_masked_softmax_cuda = cpp_extension.load(
-        name='scaled_upper_triang_masked_softmax_cuda',
-        sources=[srcpath / 'scaled_upper_triang_masked_softmax.cpp',
-                 srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu'],
-        build_directory=buildpath,
-        extra_cflags=['-O3',],
-        extra_cuda_cflags=['-O3',
-                           '-gencode', 'arch=compute_70,code=sm_70',
-                           '-U__CUDA_NO_HALF_OPERATORS__',
-                           '-U__CUDA_NO_HALF_CONVERSIONS__',
-                           '--expt-relaxed-constexpr',
-                           '--expt-extended-lambda',
-                           '--use_fast_math'] + cc_flag)
-
-
-def load_scaled_masked_softmax_fusion_kernel():
-    # Check, if CUDA11 is installed for compute capability 8.0
-    cc_flag = []
-    _, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
-    if int(bare_metal_major) >= 11:
-        cc_flag.append('-gencode')
-        cc_flag.append('arch=compute_80,code=sm_80')
-
-    srcpath = pathlib.Path(__file__).parent.absolute()
-    buildpath = srcpath / 'build'
-    create_build_dir(buildpath)
-
-    scaled_upper_triang_masked_softmax_cuda = cpp_extension.load(
-        name='scaled_masked_softmax_cuda',
-        sources=[srcpath / 'scaled_masked_softmax.cpp',
-                 srcpath / 'scaled_masked_softmax_cuda.cu'],
-        build_directory=buildpath,
-        extra_cflags=['-O3',],
-        extra_cuda_cflags=['-O3',
-                           '-gencode', 'arch=compute_70,code=sm_70',
-                           '-U__CUDA_NO_HALF_OPERATORS__',
-                           '-U__CUDA_NO_HALF_CONVERSIONS__',
-                           '--expt-relaxed-constexpr',
-                           '--expt-extended-lambda',
-                           '--use_fast_math'] + cc_flag)
-
-
-def load_fused_mix_prec_layer_norm_kernel():
-    # Check, if CUDA11 is installed for compute capability 8.0
-    cc_flag = []
-    _, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
-    if int(bare_metal_major) >= 11:
-        cc_flag.append('-gencode')
-        cc_flag.append('arch=compute_80,code=sm_80')
-
-    srcpath = pathlib.Path(__file__).parent.absolute()
-    buildpath = srcpath / 'build'
-    create_build_dir(buildpath)
-
-    fused_mix_prec_layer_norm_cuda = cpp_extension.load(
-        name='fused_mix_prec_layer_norm_cuda',
-        sources=[srcpath / 'layer_norm_cuda.cpp',
-                 srcpath / 'layer_norm_cuda_kernel.cu'],
-        build_directory=buildpath,
-        extra_cflags=['-O3'],
-        extra_cuda_cflags=['-O3',
-                           '-gencode', 'arch=compute_70,code=sm_70',
-                           '-maxrregcount=50',
-                           '--use_fast_math'] + cc_flag)
+    _create_build_dir(buildpath)
+
+    # Helper function to build the kernels.
+    def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
+        return cpp_extension.load(
+            name=name,
+            sources=sources,
+            build_directory=buildpath,
+            extra_cflags=['-O3',],
+            extra_cuda_cflags=['-O3',
+                               '-gencode', 'arch=compute_70,code=sm_70',
+                               '--use_fast_math'] + extra_cuda_flags + cc_flag,
+            verbose=(args.rank == 0)
+        )
+
+    # ==============
+    # Fused softmax.
+    # ==============
+
+    if args.masked_softmax_fusion:
+        extra_cuda_flags = ['-U__CUDA_NO_HALF_OPERATORS__',
+                            '-U__CUDA_NO_HALF_CONVERSIONS__',
+                            '--expt-relaxed-constexpr',
+                            '--expt-extended-lambda']
+
+        # Upper triangular softmax.
+        sources = [srcpath / 'scaled_upper_triang_masked_softmax.cpp',
+                   srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu']
+        scaled_upper_triang_masked_softmax_cuda = _cpp_extention_load_helper(
+            "scaled_upper_triang_masked_softmax_cuda", sources, extra_cuda_flags)
+
+        # Masked softmax.
+        sources = [srcpath / 'scaled_masked_softmax.cpp',
+                   srcpath / 'scaled_masked_softmax_cuda.cu']
+        scaled_masked_softmax_cuda = _cpp_extention_load_helper(
+            "scaled_masked_softmax_cuda", sources, extra_cuda_flags)
+
+    # =================================
+    # Mixed precision fused layer norm.
+    # =================================
+
+    extra_cuda_flags = ['-maxrregcount=50']
+    sources = [srcpath / 'layer_norm_cuda.cpp',
+               srcpath / 'layer_norm_cuda_kernel.cu']
+    fused_mix_prec_layer_norm_cuda = _cpp_extention_load_helper(
+        "fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags)
+
+
+def _get_cuda_bare_metal_version(cuda_dir):
+    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
+                                         universal_newlines=True)
+    output = raw_output.split()
+    release_idx = output.index("release") + 1
+    release = output[release_idx].split(".")
+    bare_metal_major = release[0]
+    bare_metal_minor = release[1][0]
+
+    return raw_output, bare_metal_major, bare_metal_minor
+
+
+def _create_build_dir(buildpath):
+    try:
+        os.mkdir(buildpath)
+    except OSError:
+        if not os.path.isdir(buildpath):
+            print(f"Creation of the build directory {buildpath} failed")
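The three separate load_* entry points are folded into a single load(args) that JIT-builds the two fused-softmax extensions only when args.masked_softmax_fusion is set and always builds the mixed-precision layer-norm extension. A minimal sketch of invoking it outside the normal initialization path; the SimpleNamespace stands in for Megatron's parsed arguments and carries only the two fields the loader reads, which is an assumption made for illustration:

# Requires a CUDA toolkit visible to torch.utils.cpp_extension.
from types import SimpleNamespace
from megatron import fused_kernels

args = SimpleNamespace(rank=0, masked_softmax_fusion=True)
fused_kernels.load(args)  # JIT-builds the softmax and layer-norm extensions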
megatron/fused_kernels/layer_norm_cuda.cpp (+32 -91)

@@ -24,16 +24,12 @@
 #include "compat.h"

 namespace {

 void compute_n1_n2(
     at::Tensor input,
-#ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-#else
-    at::IntList normalized_shape,
-#endif
     int& n1,
     int& n2)
 {
     int idiff = input.ndimension() - normalized_shape.size();
     n2 = 1;
     for (int i = 0;  i < (int)normalized_shape.size();  ++i) {

@@ -47,11 +43,7 @@ void compute_n1_n2(
 }

 void check_args(
-#ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-#else
-    at::IntList normalized_shape,
-#endif
     at::Tensor gamma,
     at::Tensor beta
     )

@@ -62,11 +54,7 @@ void check_args(
 void check_args(
     at::Tensor input,
-#ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-#else
-    at::IntList normalized_shape,
-#endif
     int& n1,
     int& n2
     )

@@ -102,11 +90,7 @@ void check_args(
 void check_args(
     at::Tensor input,
-#ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-#else
-    at::IntList normalized_shape,
-#endif
     at::Tensor gamma,
     at::Tensor beta,
     int& n1,

@@ -125,60 +109,42 @@ void cuda_layer_norm(
     at::Tensor* input,
     int n1,
     int n2,
-#ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-#else
-    at::IntList normalized_shape,
-#endif
     at::Tensor* gamma,
     at::Tensor* beta,
     double epsilon);

-#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
 #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
 #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

-std::vector<at::Tensor> layer_norm(
-    at::Tensor input,
-    #ifdef VERSION_GE_1_1
-    at::IntArrayRef normalized_shape,
-    #else
-    at::IntList normalized_shape,
-    #endif
-    double epsilon) {
-  CHECK_INPUT(input);
-  int n1, n2;
-  check_args(input, normalized_shape, n1, n2);
-  at::Tensor output = at::empty_like(input);
-  at::Tensor mean = at::empty({n1}, input.options().dtype(
-      input.scalar_type() == at::ScalarType::Half ? at::ScalarType::Float : input.scalar_type()));
-  at::Tensor invvar = at::empty_like(mean);
-  cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2,
-      normalized_shape, NULL, NULL, epsilon);
-  return {output, mean, invvar};
-}
-
 std::vector<at::Tensor> layer_norm_affine(
     at::Tensor input,
-    #ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-    #else
-    at::IntList normalized_shape,
-    #endif
     at::Tensor gamma,
     at::Tensor beta,
     double epsilon) {
+
   CHECK_INPUT(input);
   CHECK_INPUT(gamma);
   CHECK_INPUT(beta);
   int n1, n2;
   check_args(input, normalized_shape, gamma, beta, n1, n2);
-  at::Tensor output = at::empty_like(input, input.options().dtype(at::ScalarType::Half));
-  at::Tensor mean = at::empty({n1}, input.options().dtype(
-      input.scalar_type() == at::ScalarType::Half ? at::ScalarType::Float : input.scalar_type()));
+
+  at::Tensor output = at::empty_like(
+      input, gamma.options().dtype(gamma.scalar_type()));
+  at::Tensor mean = at::empty(
+      {n1}, input.options().dtype(at::ScalarType::Float));
   at::Tensor invvar = at::empty_like(mean);
+
   cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2,
       normalized_shape, &gamma, &beta, epsilon);
+
   return {output, mean, invvar};
 }


 void cuda_layer_norm_gradient(
     at::Tensor* dout,
     at::Tensor* mean,

@@ -186,11 +152,7 @@ void cuda_layer_norm_gradient(
     at::Tensor* input,
     int n1,
     int n2,
-#ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-#else
-    at::IntList normalized_shape,
-#endif
     at::Tensor* gamma,
     at::Tensor* beta,
     double epsilon,

@@ -199,62 +161,41 @@ void cuda_layer_norm_gradient(
     at::Tensor* grad_gamma,
     at::Tensor* grad_beta);

-at::Tensor layer_norm_gradient(
-    at::Tensor dout,
-    at::Tensor mean,
-    at::Tensor invvar,
-    at::Tensor input,
-    #ifdef VERSION_GE_1_1
-    at::IntArrayRef normalized_shape,
-    #else
-    at::IntList normalized_shape,
-    #endif
-    double epsilon) {
-  CHECK_INPUT(dout);
-  CHECK_INPUT(mean);
-  CHECK_INPUT(invvar);
-  CHECK_INPUT(input);
-  int n1, n2;
-  check_args(input, normalized_shape, n1, n2);
-  at::Tensor grad_input = at::empty_like(input);
-  cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2,
-      normalized_shape, NULL, NULL, epsilon,
-      &grad_input, NULL, NULL);
-  return grad_input;
-}
-
 std::vector<at::Tensor> layer_norm_gradient_affine(
     at::Tensor dout,
     at::Tensor mean,
     at::Tensor invvar,
     at::Tensor input,
-    #ifdef VERSION_GE_1_1
     at::IntArrayRef normalized_shape,
-    #else
-    at::IntList normalized_shape,
-    #endif
     at::Tensor gamma,
     at::Tensor beta,
     double epsilon) {
+
   CHECK_INPUT(dout);
   CHECK_INPUT(mean);
   CHECK_INPUT(invvar);
   CHECK_INPUT(input);
   CHECK_INPUT(gamma);
   CHECK_INPUT(beta);
   int n1, n2;
   check_args(input, normalized_shape, gamma, beta, n1, n2);
+
   at::Tensor grad_input = at::empty_like(input);
   at::Tensor grad_gamma = at::empty_like(gamma);
   at::Tensor grad_beta = at::empty_like(beta);
+
   cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2,
       normalized_shape, &gamma, &beta, epsilon,
       &grad_input, &grad_gamma, &grad_beta);
+
   return {grad_input, grad_gamma, grad_beta};
 }

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)");
-  m.def("forward", &layer_norm, "LayerNorm forward (CUDA)");
   m.def("backward_affine", &layer_norm_gradient_affine, "LayerNorm backward (CUDA)");
-  m.def("backward", &layer_norm_gradient, "LayerNorm backward (CUDA)");
 }
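With the non-affine entry points removed, the extension now exposes only forward_affine and backward_affine, and the forward output dtype follows gamma rather than being hard-coded to half. A hypothetical sketch of calling the built module directly; the module name and argument order are read off the bindings above, while the direct import, the eps value, and the GPU availability are assumptions (in the repo this is normally wrapped by an autograd function rather than called by hand):

# Assumes fused_kernels.load(...) has already JIT-built the extension.
import torch
import fused_mix_prec_layer_norm_cuda as lnorm  # hypothetical direct import

hidden = 1024
x = torch.randn(8, hidden, dtype=torch.half, device="cuda")
gamma = torch.ones(hidden, dtype=torch.half, device="cuda")
beta = torch.zeros(hidden, dtype=torch.half, device="cuda")

output, mean, invvar = lnorm.forward_affine(x, (hidden,), gamma, beta, 1e-5)
print(output.dtype, mean.dtype)  # output matches gamma's dtype; fp32 statistics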
megatron/fused_kernels/layer_norm_cuda_kernel.cu (+33 -33)

@@ -285,15 +285,6 @@ struct SharedMemory <float>
     }
 };

-template <>
-struct SharedMemory <double>
-{
-    __device__ double *getPointer()
-    {
-        extern __shared__ double s_double[];
-        return s_double;
-    }
-};
 }

 template<typename T, typename U, typename V> __global__

@@ -656,6 +647,9 @@ void cuComputeGradInput(
     }
   }
 }

+template<typename T, typename U, typename V>
 void HostApplyLayerNorm(
     V* output,

@@ -671,7 +665,8 @@ void HostApplyLayerNorm(
 {
     auto stream = at::cuda::getCurrentCUDAStream().stream();
     const dim3 threads(32,4,1);
-    const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
+    const uint64_t maxGridY =
+        at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
     const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
     int nshared =
         threads.y > 1 ?

@@ -687,6 +682,7 @@ void HostApplyLayerNorm(
           gamma, beta);
 }

+
 void cuda_layer_norm(
     at::Tensor* output,
     at::Tensor* mean,

@@ -704,21 +700,21 @@ void cuda_layer_norm(
     double epsilon)
 {
     using namespace at;
-    DISPATCH_DOUBLE_FLOAT_AND_HALF(input->scalar_type(), 0, "layer_norm_cuda_kernel",
-        using accscalar_t = at::acc_type<scalar_t_0, true>;
-        using output_t = at::Half;
-        HostApplyLayerNorm(
-            output->DATA_PTR<output_t>(),
-            mean->DATA_PTR<accscalar_t>(),
-            invvar->DATA_PTR<accscalar_t>(),
-            input->DATA_PTR<scalar_t_0>(),
-            n1, n2,
-            epsilon,
-            gamma != NULL ? gamma->DATA_PTR<output_t>() : NULL,
-            beta != NULL ? beta->DATA_PTR<output_t>() : NULL);
+    DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
+        input->scalar_type(), output->scalar_type(),
+        "cuda_layer_norm_kernel",
+        HostApplyLayerNorm(
+            output->DATA_PTR<scalar_t_out>(),
+            mean->DATA_PTR<float>(),
+            invvar->DATA_PTR<float>(),
+            input->DATA_PTR<scalar_t_in>(),
+            n1, n2,
+            epsilon,
+            gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
+            beta != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL);
+        )
 }

 template<typename T, typename U, typename V>
 void HostLayerNormGradient(
     const V* dout,

@@ -742,10 +738,12 @@ void HostLayerNormGradient(
       const int part_size = 16;
       const dim3 threads2(32,4,1);
       const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1);
-      const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
+      const int nshared2_a = 2 * sizeof(U) * threads2.y *
+          threads2.y * (threads2.x + 1);
       const int nshared2_b = threads2.x * threads2.y * sizeof(U);
       const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
-      at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype(
-          input->scalar_type() == at::ScalarType::Half ? at::ScalarType::Float : input->scalar_type()));
+      at::Tensor part_grad_gamma = at::empty(
+          {part_size,n2}, input->options().dtype(at::ScalarType::Float));
       at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
       cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
           dout,

@@ -770,7 +768,8 @@ void HostLayerNormGradient(
     }

     // compute grad_input
-    const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
+    const uint64_t maxGridY =
+        at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
     const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
     const dim3 threads1(32,4,1);
     int nshared =

@@ -788,6 +787,7 @@ void HostLayerNormGradient(
         grad_input);
 }

+
 void cuda_layer_norm_gradient(
     at::Tensor* dout,
     at::Tensor* mean,

@@ -808,22 +808,22 @@ void cuda_layer_norm_gradient(
     at::Tensor* grad_beta)
 {
     using namespace at;
-    DISPATCH_FLOAT_AND_HALF(input->scalar_type(), 0, "cuComputeGradInput",
-        using accscalar_t = at::acc_type<scalar_t_0, true>;
-        using output_t = at::Half;
-        HostLayerNormGradient(
-            dout->DATA_PTR<output_t>(),
-            mean->DATA_PTR<accscalar_t>(),
-            invvar->DATA_PTR<accscalar_t>(),
+    DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
+        input->scalar_type(), gamma->scalar_type(),
+        "cuda_layer_norm_gradient_kernel",
+        HostLayerNormGradient(
+            dout->DATA_PTR<scalar_t_out>(),
+            mean->DATA_PTR<float>(),
+            invvar->DATA_PTR<float>(),
             input,
             n1, n2,
             // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
             // if gamma Tensor is NULL on input.
-            gamma != NULL ? gamma->DATA_PTR<output_t>() : NULL,
-            gamma != NULL ? beta->DATA_PTR<output_t>() : NULL,
+            gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
+            gamma != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL,
             epsilon,
-            grad_input->DATA_PTR<scalar_t_0>(),
-            gamma != NULL ? grad_gamma->DATA_PTR<output_t>() : NULL,
-            gamma != NULL ? grad_beta->DATA_PTR<output_t>() : NULL);
+            grad_input->DATA_PTR<scalar_t_in>(),
+            gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL,
+            gamma != NULL ? grad_beta->DATA_PTR<scalar_t_out>() : NULL);
+        )
 }
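The dispatch change means the kernel is now instantiated for separate input and output dtypes (float, half, bfloat16) while the mean and inverse variance are always stored in float. A plain-PyTorch reference sketch of what the fused forward returns; this is not the extension's API, just the equivalent math:

# Reference: output, per-row mean, and inverse standard deviation,
# with statistics accumulated in fp32 and output following gamma's dtype.
import torch

def layer_norm_affine_reference(x, gamma, beta, eps=1e-5):
    mean = x.float().mean(dim=-1)
    var = x.float().var(dim=-1, unbiased=False)
    invvar = torch.rsqrt(var + eps)
    y = (x.float() - mean[..., None]) * invvar[..., None]
    return (y * gamma.float() + beta.float()).to(gamma.dtype), mean, invvar

x = torch.randn(4, 16, dtype=torch.bfloat16)
gamma = torch.ones(16, dtype=torch.bfloat16)
beta = torch.zeros(16, dtype=torch.bfloat16)
out, mean, invvar = layer_norm_affine_reference(x, gamma, beta)
print(out.dtype, mean.dtype, invvar.dtype)  # torch.bfloat16 torch.float32 torch.float32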
megatron/fused_kernels/scaled_masked_softmax.cpp (+9 -6)

@@ -37,8 +37,9 @@ torch::Tensor fwd(
     torch::Tensor const& mask,
     float scale_factor) {
   AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
-  AT_ASSERTM(input.scalar_type() == at::ScalarType::Half,
-      "Only HALF is supported");
+  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
+             (input.scalar_type() == at::ScalarType::BFloat16),
+      "Only fp16 and bf16 are supported");
   AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");

   return fwd_cuda(input, mask, scale_factor);

@@ -52,10 +53,12 @@ torch::Tensor bwd(
   AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor");
   AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor");

-  AT_ASSERTM(output_grads.scalar_type() == at::ScalarType::Half,
-      "Only HALF is supported");
-  AT_ASSERTM(softmax_results.scalar_type() == at::ScalarType::Half,
-      "Only HALF is supported");
+  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
+             (output_grads.scalar_type() == at::ScalarType::BFloat16),
+      "Only fp16 and bf16 are supported");
+  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
+             (softmax_results.scalar_type() == at::ScalarType::BFloat16),
+      "Only fp16 and bf16 are supported");

   return bwd_cuda(output_grads, softmax_results, scale_factor);
 }
megatron/fused_kernels/scaled_masked_softmax.h (+123 -83): diff not expanded in this capture.
megatron/fused_kernels/scaled_masked_softmax_cuda.cu (+37 -27)

@@ -19,10 +19,10 @@
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
 #include <cuda_profiler_api.h>
-#include "THC/THC.h"
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
 #include "scaled_masked_softmax.h"
+#include "type_shim.h"

 namespace multihead_attn {
 namespace fused_softmax {

@@ -37,33 +37,39 @@ torch::Tensor fwd_cuda(
   const int batches = input.size(0);
   const int pad_batches = mask.size(0);
   const int attn_heads = input.size(1);
-  const int seq_len = input.size(2);
-  TORCH_INTERNAL_ASSERT(seq_len <= 2048);
+  const int query_seq_len = input.size(2);
+  const int key_seq_len = input.size(3);
+  TORCH_INTERNAL_ASSERT(key_seq_len <= 2048);
+  TORCH_INTERNAL_ASSERT(query_seq_len > 1);
   TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
   TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
-  TORCH_INTERNAL_ASSERT(mask.size(2) == seq_len);
-  TORCH_INTERNAL_ASSERT(mask.size(3) == seq_len);
+  TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);
+  TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);

   // Output
   auto act_options = input.options().requires_grad(false);
   torch::Tensor softmax_results =
-      torch::empty({batches, attn_heads, seq_len, seq_len}, act_options);
+      torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);

   // Softmax Intermediate Result Ptr
   void* input_ptr = static_cast<void*>(input.data_ptr());
   void* mask_ptr = static_cast<void*>(mask.data_ptr());
   void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());

-  dispatch_scaled_masked_softmax_forward<half, half, float>(
-      reinterpret_cast<half*>(softmax_results_ptr),
-      reinterpret_cast<const half*>(input_ptr),
-      reinterpret_cast<const uint8_t*>(mask_ptr),
-      scale_factor,
-      seq_len,
-      seq_len,
-      batches,
-      attn_heads,
-      pad_batches);
+  DISPATCH_HALF_AND_BFLOAT(
+      input.scalar_type(),
+      "dispatch_scaled_masked_softmax_forward",
+      dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
+          reinterpret_cast<scalar_t*>(softmax_results_ptr),
+          reinterpret_cast<const scalar_t*>(input_ptr),
+          reinterpret_cast<const uint8_t*>(mask_ptr),
+          scale_factor,
+          query_seq_len,
+          key_seq_len,
+          batches,
+          attn_heads,
+          pad_batches);
+      );

   return softmax_results;
 }

@@ -78,21 +84,25 @@ torch::Tensor bwd_cuda(
   //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
   const int batches = output_grads.size(0);
   const int attn_heads = output_grads.size(1);
-  const int seq_len = output_grads.size(2);
-  TORCH_INTERNAL_ASSERT(output_grads.size(2) == output_grads.size(3));
+  const int query_seq_len = output_grads.size(2);
+  const int key_seq_len = output_grads.size(3);

   void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

   //Softmax Grad
-  dispatch_scaled_masked_softmax_backward<half, half, float>(
-      reinterpret_cast<half*>(output_grads_ptr),
-      reinterpret_cast<half*>(output_grads_ptr),
-      reinterpret_cast<half const*>(softmax_results.data_ptr()),
-      scale_factor,
-      seq_len,
-      seq_len,
-      batches,
-      attn_heads);
+  DISPATCH_HALF_AND_BFLOAT(
+      output_grads_.scalar_type(),
+      "dispatch_scaled_masked_softmax_backward",
+      dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
+          reinterpret_cast<scalar_t*>(output_grads_ptr),
+          reinterpret_cast<scalar_t*>(output_grads_ptr),
+          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
+          scale_factor,
+          query_seq_len,
+          key_seq_len,
+          batches,
+          attn_heads);
+      );

   //backward pass is completely in-place
   return output_grads;
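The forward path now takes distinct query and key sequence lengths and dispatches on fp16/bf16 instead of assuming half. A plain-PyTorch reference sketch of the operation being fused, assuming mask entries equal to 1 mark padded keys to exclude (the mask convention here is an assumption for illustration, not taken from the source):

# Reference: scaled softmax over the last dim of
# [batches, heads, query_seq_len, key_seq_len] scores with a pad mask.
import torch

def scaled_masked_softmax_reference(x, mask, scale):
    x = x.float() * scale
    x = x.masked_fill(mask.bool(), float("-inf"))
    return torch.softmax(x, dim=-1)

b, h, q, k = 2, 4, 8, 8
scores = torch.randn(b, h, q, k, dtype=torch.half)
pad_mask = torch.zeros(b, 1, q, k, dtype=torch.uint8)
pad_mask[..., -2:] = 1                     # last two keys are padding
probs = scaled_masked_softmax_reference(scores, pad_mask, scale=0.125)
print(probs.shape, probs[0, 0, 0, -2:])    # masked keys get ~0 probability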
megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp (+10 -7)

@@ -33,8 +33,9 @@ torch::Tensor bwd_cuda(
 torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
   AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
-  AT_ASSERTM(input.scalar_type() == at::ScalarType::Half,
-      "Only HALF is supported");
+  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
+             (input.scalar_type() == at::ScalarType::BFloat16),
+      "Only fp16 and bf16 are supported");

   return fwd_cuda(input, scale_factor);
 }

@@ -47,10 +48,12 @@ torch::Tensor bwd(
   AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
   AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");

-  AT_ASSERTM(output_grads.scalar_type() == at::ScalarType::Half,
-      "Only HALF is supported");
-  AT_ASSERTM(softmax_results.scalar_type() == at::ScalarType::Half,
-      "Only HALF is supported");
+  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
+             (output_grads.scalar_type() == at::ScalarType::BFloat16),
+      "Only fp16 and bf16 are supported");
+  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
+             (softmax_results.scalar_type() == at::ScalarType::BFloat16),
+      "Only fp16 and bf16 are supported");

   return bwd_cuda(output_grads, softmax_results, scale_factor);
 }

@@ -61,7 +64,7 @@ torch::Tensor bwd(
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("forward",
         &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
         "Self Multihead Attention scaled, time masked softmax -- Forward.");
   m.def("backward",
         &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
megatron/fused_kernels/scaled_upper_triang_masked_softmax.h
View file @
aed2f75e
...
...
@@ -21,11 +21,47 @@
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <cuda_fp16.h>
#include <c10/macros/Macros.h>
namespace
{
template
<
typename
Datatype
,
int
ELEMENTS_PER_LDG
>
__device__
__inline__
void
copy_vector
(
Datatype
*
dst
,
const
Datatype
*
src
);
template
<
>
__device__
__inline__
void
copy_vector
<
c10
::
BFloat16
,
1
>
(
c10
::
BFloat16
*
dst
,
const
c10
::
BFloat16
*
src
)
{
*
dst
=
*
src
;
}
template
<
>
__device__
__inline__
void
copy_vector
<
c10
::
BFloat16
,
4
>
(
c10
::
BFloat16
*
dst
,
const
c10
::
BFloat16
*
src
)
{
*
((
float2
*
)
dst
)
=
*
((
float2
*
)
src
);
}
template
<
>
__device__
__inline__
void
copy_vector
<
c10
::
Half
,
1
>
(
c10
::
Half
*
dst
,
const
c10
::
Half
*
src
)
{
*
dst
=
*
src
;
}
template
<
>
__device__
__inline__
void
copy_vector
<
c10
::
Half
,
4
>
(
c10
::
Half
*
dst
,
const
c10
::
Half
*
src
)
{
*
((
float2
*
)
dst
)
=
*
((
float2
*
)
src
);
}
template
<
>
__device__
__inline__
void
copy_vector
<
uint8_t
,
1
>
(
uint8_t
*
dst
,
const
uint8_t
*
src
)
{
*
dst
=
*
src
;
}
template
<
>
__device__
__inline__
void
copy_vector
<
uint8_t
,
4
>
(
uint8_t
*
dst
,
const
uint8_t
*
src
)
{
*
((
half2
*
)
dst
)
=
*
((
half2
*
)
src
);
}
template
<
typename
Datatype
,
int
ELEMENTS_PER_LDG
>
__device__
__inline__
void
copy_zero_vector
(
Datatype
*
dst
);
template
<
>
__device__
__inline__
void
copy_zero_vector
<
c10
::
BFloat16
,
1
>
(
c10
::
BFloat16
*
dst
)
{
*
dst
=
0.0
;
}
template
<
>
__device__
__inline__
void
copy_zero_vector
<
c10
::
BFloat16
,
4
>
(
c10
::
BFloat16
*
dst
)
{
*
((
float2
*
)
dst
)
=
make_float2
(
0.0
f
,
0.0
f
);
}
template
<
>
__device__
__inline__
void
copy_zero_vector
<
c10
::
Half
,
1
>
(
c10
::
Half
*
dst
)
{
*
dst
=
0.0
;
}
template
<
>
__device__
__inline__
void
copy_zero_vector
<
c10
::
Half
,
4
>
(
c10
::
Half
*
dst
)
{
*
((
float2
*
)
dst
)
=
make_float2
(
0.0
f
,
0.0
f
);
}
int
log2_ceil
(
int
value
)
{
int
log2_value
=
0
;
while
((
1
<<
log2_value
)
<
value
)
++
log2_value
;
...
...
@@ -73,7 +109,7 @@ __device__ __forceinline__ void warp_reduce(acc_t* sum) {
* Extended softmax (from native aten pytorch) with following additional features
* 1) input scaling
* 2) Implicit time (diagonal masking)
*/
*/
template
<
typename
input_t
,
typename
output_t
,
typename
acc_t
,
int
log2_elements
>
__global__
void
scaled_upper_triang_masked_softmax_warp_forward
(
output_t
*
dst
,
...
...
@@ -89,10 +125,11 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
constexpr
int
WARP_SIZE
=
(
next_power_of_two
<
C10_WARP_SIZE
)
?
next_power_of_two
:
C10_WARP_SIZE
;
constexpr
int
WARP_ITERATIONS
=
next_power_of_two
/
WARP_SIZE
;
constexpr
int
WARP_BATCH
=
(
next_power_of_two
<=
128
)
?
2
:
1
;
constexpr
int
ELEMENTS_PER_LDG_STG
=
4
;
int
first_batch
=
(
blockDim
.
y
*
blockIdx
.
y
+
threadIdx
.
y
)
*
gridDim
.
x
*
WARP_BATCH
+
blockIdx
.
x
;
int
local_seq
=
blockIdx
.
x
+
1
;
int
warp_iteration_limit
=
(
local_seq
+
WARP_SIZE
-
1
)
/
WARP_SIZE
;
int
warp_iteration_limit
=
(
local_seq
+
ELEMENTS_PER_LDG_STG
*
WARP_SIZE
-
1
)
/
WARP_SIZE
;
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
...
...
@@ -103,22 +140,36 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
// there might be multiple batches per warp. compute the index within the batch
int
local_idx
=
threadIdx
.
x
;
src
+=
first_batch
*
stride
+
local_idx
;
dst
+=
first_batch
*
stride
+
local_idx
;
src
+=
first_batch
*
stride
+
ELEMENTS_PER_LDG_STG
*
local_idx
;
dst
+=
first_batch
*
stride
+
ELEMENTS_PER_LDG_STG
*
local_idx
;
// load data from global memory
acc_t
elements
[
WARP_BATCH
][
WARP_ITERATIONS
];
input_t
temp_data
[
ELEMENTS_PER_LDG_STG
];
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
int
batch_element_count
=
(
i
>=
local_batches
)
?
0
:
local_seq
;
#pragma unroll
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
++
it
)
{
int
element_index
=
local_idx
+
it
*
WARP_SIZE
;
#pragma unroll
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
it
+=
ELEMENTS_PER_LDG_STG
)
{
int
element_index
=
ELEMENTS_PER_LDG_STG
*
local_idx
+
it
*
WARP_SIZE
;
if
(
element_index
<
batch_element_count
)
{
elements
[
i
][
it
]
=
(
acc_t
)
src
[
i
*
element_count
*
stride
+
it
*
WARP_SIZE
]
*
scale
;
copy_vector
<
input_t
,
ELEMENTS_PER_LDG_STG
>
(
temp_data
,
src
+
i
*
element_count
*
stride
+
it
*
WARP_SIZE
);
#pragma unroll
for
(
int
element
=
0
;
element
<
ELEMENTS_PER_LDG_STG
;
++
element
)
{
if
((
element_index
+
element
)
<
batch_element_count
)
{
elements
[
i
][
it
+
element
]
=
(
acc_t
)
temp_data
[
element
]
*
scale
;
}
else
{
elements
[
i
][
it
+
element
]
=
-
std
::
numeric_limits
<
acc_t
>::
infinity
();
}
}
}
else
{
elements
[
i
][
it
]
=
-
std
::
numeric_limits
<
acc_t
>::
infinity
();
#pragma unroll
for
(
int
element
=
0
;
element
<
ELEMENTS_PER_LDG_STG
;
++
element
)
{
elements
[
i
][
it
+
element
]
=
-
std
::
numeric_limits
<
acc_t
>::
infinity
();
}
}
}
}
...
...
@@ -140,26 +191,37 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
#pragma unroll
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
++
it
)
{
if
(
it
<
warp_iteration_limit
)
{
if
(
it
<
warp_iteration_limit
)
{
elements
[
i
][
it
]
=
std
::
exp
((
elements
[
i
][
it
]
-
max_value
[
i
]));
sum
[
i
]
+=
elements
[
i
][
it
];
}
}
}
}
warp_reduce
<
acc_t
,
WARP_BATCH
,
WARP_SIZE
,
Add
>
(
sum
);
// store result
output_t
out
[
ELEMENTS_PER_LDG_STG
];
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_BATCH
;
++
i
)
{
if
(
i
>=
local_batches
)
break
;
#pragma unroll
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
++
it
)
{
int
element_index
=
local_idx
+
it
*
WARP_SIZE
;
for
(
int
it
=
0
;
it
<
WARP_ITERATIONS
;
it
+=
ELEMENTS_PER_LDG_STG
)
{
int
element_index
=
ELEMENTS_PER_LDG_STG
*
local_idx
+
it
*
WARP_SIZE
;
if
(
element_index
<
local_seq
)
{
dst
[
i
*
element_count
*
stride
+
it
*
WARP_SIZE
]
=
(
output_t
)(
elements
[
i
][
it
]
/
sum
[
i
]);
#pragma unroll
for
(
int
element
=
0
;
element
<
ELEMENTS_PER_LDG_STG
;
++
element
)
{
if
(
element_index
+
element
<
local_seq
)
{
out
[
element
]
=
elements
[
i
][
it
+
element
]
/
sum
[
i
];
}
else
{
out
[
element
]
=
0
;
}
}
copy_vector
<
output_t
,
ELEMENTS_PER_LDG_STG
>
(
dst
+
i
*
element_count
*
stride
+
it
*
WARP_SIZE
,
out
);
}
else
if
(
element_index
<
element_count
)
{
dst
[
i
*
element_count
*
stride
+
it
*
WARP_SIZE
]
=
0
;
copy_zero_vector
<
output_t
,
ELEMENTS_PER_LDG_STG
>
(
dst
+
i
*
element_count
*
stride
+
it
*
WARP_SIZE
)
;
}
else
{
break
;
}
...
...
@@ -183,6 +245,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
constexpr
int
WARP_SIZE
=
(
next_power_of_two
<
C10_WARP_SIZE
)
?
next_power_of_two
:
C10_WARP_SIZE
;
constexpr
int
WARP_ITERATIONS
=
next_power_of_two
/
WARP_SIZE
;
constexpr
int
WARP_BATCH
=
(
next_power_of_two
<=
128
)
?
2
:
1
;
constexpr
int
ELEMENTS_PER_LDG_STG
=
4
;
int
first_batch
=
(
blockDim
.
y
*
blockIdx
.
y
+
threadIdx
.
y
)
*
gridDim
.
x
*
WARP_BATCH
+
blockIdx
.
x
;
int
local_seq
=
blockIdx
.
x
+
1
;
...
...
@@ -197,37 +260,41 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
    int local_idx = threadIdx.x;

    // the first element to process by the current thread
-   int thread_offset = first_batch * stride + local_idx;
+   int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
    grad += thread_offset;
    output += thread_offset;
    gradInput += thread_offset;

    // load data from global memory
    acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
-   acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
+   acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
+   input_t temp_grad[ELEMENTS_PER_LDG_STG];
+   input_t temp_output[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : local_seq;

        #pragma unroll
-       for (int it = 0; it < WARP_ITERATIONS; ++it) {
-           int element_index = local_idx + it * WARP_SIZE;
-           if (element_index < batch_element_count) {
-               output_reg[i][it] = output[i * element_count * stride + it * WARP_SIZE];
-           } else {
-               output_reg[i][it] = acc_t(0);
-           }
-       }
-
-       #pragma unroll
-       for (int it = 0; it < WARP_ITERATIONS; ++it) {
-           int element_index = local_idx + it * WARP_SIZE;
-           if (element_index < batch_element_count) {
-               grad_reg[i][it] = (acc_t)grad[i * element_count * stride + it * WARP_SIZE]
-                                 * output_reg[i][it];
-           } else {
-               grad_reg[i][it] = acc_t(0);
-           }
-       }
+       for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
+           int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
+           if (element_index < batch_element_count) {
+               copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
+                   temp_grad, grad + i * element_count * stride + it * WARP_SIZE);
+               copy_vector<input_t, ELEMENTS_PER_LDG_STG>(
+                   temp_output, output + i * element_count * stride + it * WARP_SIZE);
+
+               #pragma unroll
+               for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                   if (element_index + element < batch_element_count) {
+                       output_reg[i][it + element] = (acc_t)temp_output[element];
+                   }
+               }
+               #pragma unroll
+               for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                   if (element_index + element < batch_element_count) {
+                       grad_reg[i][it + element] =
+                           (acc_t)temp_grad[element] * output_reg[i][it + element];
+                   }
+               }
+           }
+       }
    }

    acc_t sum[WARP_BATCH];
...
...
@@ -247,11 +314,16 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
        if (i >= local_batches)
            break;
        #pragma unroll
-       for (int it = 0; it < WARP_ITERATIONS; ++it) {
-           int element_index = local_idx + it * WARP_SIZE;
+       for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
+           int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                // compute gradients
-               gradInput[i * element_count * stride + it * WARP_SIZE] =
-                   (output_t)(scale * (grad_reg[i][it] - output_reg[i][it] * sum[i]));
+               output_t out[ELEMENTS_PER_LDG_STG];
+               #pragma unroll
+               for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                   out[element] = (output_t)(scale * (grad_reg[i][it + element] -
+                       output_reg[i][it + element] * sum[i]));
+               }
+               copy_vector<output_t, ELEMENTS_PER_LDG_STG>(
+                   gradInput + i * element_count * stride + it * WARP_SIZE, out);
            }
        }
    }
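For reference, the stored value is the standard softmax backward, restated here rather than introduced by this commit: with y = softmax(scale * x), upstream gradient g, grad_reg holding g * y elementwise, and sum[i] holding the row sum of g * y, the gradient written out is

\[
\frac{\partial L}{\partial x_k}
  = \mathrm{scale}\,\Bigl(g_k y_k - y_k \sum_j g_j y_j\Bigr)
  = \mathrm{scale}\,\bigl(\mathtt{grad\_reg}_k - \mathtt{output\_reg}_k \cdot \mathtt{sum}\bigr).
\]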
...
...
megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu
...
...
@@ -19,10 +19,10 @@
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_upper_triang_masked_softmax.h"
#include "type_shim.h"
namespace multihead_attn {
namespace fused_softmax {
...
...
@@ -46,15 +46,20 @@ torch::Tensor fwd_cuda(
    void* input_ptr = static_cast<void*>(input.data_ptr());
    void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());

-   dispatch_scaled_upper_triang_masked_softmax_forward<half, half, float>(
-       reinterpret_cast<half*>(softmax_results_ptr),
-       reinterpret_cast<const half*>(input_ptr),
-       scale_factor,
-       seq_len,
-       seq_len,
-       attn_batches);
+   DISPATCH_HALF_AND_BFLOAT(
+       input.scalar_type(),
+       "dispatch_scaled_upper_triang_masked_softmax_forward",
+       dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t, float>(
+           reinterpret_cast<scalar_t*>(softmax_results_ptr),
+           reinterpret_cast<const scalar_t*>(input_ptr),
+           scale_factor,
+           seq_len,
+           seq_len,
+           attn_batches);
+   );

    return softmax_results;
}

torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads_,
...
...
@@ -72,14 +77,18 @@ torch::Tensor bwd_cuda(
    void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

    //Softmax Grad
-   dispatch_scaled_upper_triang_masked_softmax_backward<half, half, float>(
-       reinterpret_cast<half*>(output_grads_ptr),
-       reinterpret_cast<half*>(output_grads_ptr),
-       reinterpret_cast<half const*>(softmax_results.data_ptr()),
-       scale_factor,
-       seq_len,
-       seq_len,
-       attn_batches);
+   DISPATCH_HALF_AND_BFLOAT(
+       output_grads_.scalar_type(),
+       "dispatch_scaled_upper_triang_masked_softmax_backward",
+       dispatch_scaled_upper_triang_masked_softmax_backward<scalar_t, scalar_t, float>(
+           reinterpret_cast<scalar_t*>(output_grads_ptr),
+           reinterpret_cast<scalar_t*>(output_grads_ptr),
+           reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
+           scale_factor,
+           seq_len,
+           seq_len,
+           attn_batches);
+   );

    //backward pass is completely in-place
    return output_grads;
...
...
megatron/fused_kernels/type_shim.h
...
...
@@ -14,214 +14,78 @@
* limitations under the License.
*/
/*This code is copied from NVIDIA apex:
* https://github.com/NVIDIA/apex
* with minor changes. */
#include <ATen/ATen.h>
#include "compat.h"
// Forward/backward compatibility hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
// const at::Type& payload;
// TypeShim(const at::Type& type) : payload(type) {}
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
// operator const at::Type&(){ return payload; };
// // Enable dispatch switch statements to take *this directly for post-3aeb78
// //operator at::ScalarType(){ return payload.; };
// };
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Half: \
{ \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch(TYPEIN) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_in = float; \
switch(TYPEOUT) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
} \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_in = at::Half; \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_in = at::BFloat16; \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Byte: \
{ \
using scalar_t_##LEVEL = uint8_t; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
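A hedged usage sketch of the new DISPATCH_HALF_AND_BFLOAT macro, mirroring the call sites in the softmax .cu files above; my_launcher and the pointer names are placeholders for illustration, not functions from this commit.

// Illustration only: dispatch a templated launcher on Half vs BFloat16 inputs.
DISPATCH_HALF_AND_BFLOAT(
    input.scalar_type(),
    "my_launcher",
    my_launcher<scalar_t, scalar_t, float>(
        reinterpret_cast<scalar_t*>(out_ptr),
        reinterpret_cast<const scalar_t*>(in_ptr),
        scale_factor, seq_len, seq_len, attn_batches);
);
// Inside the macro, scalar_t becomes at::Half or at::BFloat16 depending on the
// runtime dtype; any other dtype raises "my_launcher not implemented for ...".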
template<typename T>
__device__ __forceinline__ T reduce_block_into_lanes(
    T *x,
    T val,
    int lanes = 1,
    bool share_result = false)  // lanes is intended to be <= 32.
{
    int tid = threadIdx.x + threadIdx.y * blockDim.x;
    int blockSize = blockDim.x * blockDim.y;  // blockSize is intended to be a multiple of 32.

    if (blockSize >= 64)
    {
        x[tid] = val;
        __syncthreads();
    }

    #pragma unroll
    for (int i = (blockSize >> 1); i >= 64; i >>= 1)
    {
        if (tid < i)
            x[tid] = x[tid] + x[tid + i];
        __syncthreads();
    }

    T final;

    if (tid < 32)
    {
        if (blockSize >= 64)
            final = x[tid] + x[tid + 32];
        else
            final = val;
        // __SYNCWARP();

        #pragma unroll
        for (int i = 16; i >= lanes; i >>= 1)
            final = final + __shfl_down_sync(0xffffffff, final, i);
    }

    if (share_result)
    {
        if (tid < lanes)
            x[tid] = final;  // EpilogueOp
        // Make sure the smem result is visible to all warps.
        __syncthreads();
    }

    return final;
}

template<typename T>
__device__ __forceinline__ T reduce_block_into_lanes_max_op(
    T *x,
    T val,
    int lanes = 1,
    bool share_result = false)  // lanes is intended to be <= 32.
{
    int tid = threadIdx.x + threadIdx.y * blockDim.x;
    int blockSize = blockDim.x * blockDim.y;  // blockSize is intended to be a multiple of 32.

    if (blockSize >= 64)
    {
        x[tid] = val;
        __syncthreads();
    }

    #pragma unroll
    for (int i = (blockSize >> 1); i >= 64; i >>= 1)
    {
        if (tid < i)
            x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
        __syncthreads();
    }

    T final;

    if (tid < 32)
    {
        if (blockSize >= 64)
            final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
        else
            final = val;
        // __SYNCWARP();

        #pragma unroll
        for (int i = 16; i >= lanes; i >>= 1)
            final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
    }

    if (share_result)
    {
        if (tid < lanes)
            x[tid] = final;  // EpilogueOp
        // Make sure the smem result is visible to all warps.
        __syncthreads();
    }

    return final;
}
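A minimal, hypothetical kernel showing how reduce_block_into_lanes is typically invoked (block size assumed to be a multiple of 32, shared memory sized to one float per thread); this caller is illustrative and not part of the commit.

// Hypothetical usage sketch: per-block sum of `vals`, one result per block.
__global__ void block_sum_example(const float *vals, float *out, int n) {
    extern __shared__ float smem[];                  // blockDim.x * blockDim.y floats
    int tid = threadIdx.x + threadIdx.y * blockDim.x;
    int gid = blockIdx.x * (blockDim.x * blockDim.y) + tid;
    float v = (gid < n) ? vals[gid] : 0.f;
    // lanes = 1: the full block-wide sum ends up in lane 0
    float total = reduce_block_into_lanes(smem, v);
    if (tid == 0)
        out[blockIdx.x] = total;
}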
megatron/global_vars.py
...
...
@@ -83,7 +83,8 @@ def set_global_variables(extra_args_provider=None, args_defaults={},
                              defaults=args_defaults,
                              ignore_unknown_args=ignore_unknown_args)
    _build_num_microbatches_calculator(args)
-   _ = _build_tokenizer(args)
+   if args.vocab_file:
+       _ = _build_tokenizer(args)
    _set_tensorboard_writer(args)
    _set_adlr_autoresume(args)
    _set_timers()
...
...
@@ -131,12 +132,13 @@ def _set_tensorboard_writer(args):
                                   'tensorboard writer')
    if hasattr(args, 'tensorboard_dir') and \
       args.tensorboard_dir and args.rank == (args.world_size - 1):
        try:
            from torch.utils.tensorboard import SummaryWriter
            print('> setting tensorboard ...')
            _GLOBAL_TENSORBOARD_WRITER = SummaryWriter(
-               log_dir=args.tensorboard_dir)
+               log_dir=args.tensorboard_dir,
+               max_queue=args.tensorboard_queue_size)
        except ModuleNotFoundError:
            print('WARNING: TensorBoard writing requested but is not '
                  'available (are you using PyTorch 1.1.0 or later?), '
...
...
megatron/indexer.py
import sys
import torch
import torch.distributed as dist

from megatron import get_args
from megatron import mpu
-from megatron.checkpointing import load_ict_checkpoint
-from megatron.data.ict_dataset import get_ict_dataset
-from megatron.data.realm_dataset_utils import get_one_epoch_dataloader
-from megatron.data.realm_index import detach, BlockData
-from megatron.data.realm_dataset_utils import get_ict_batch
-from megatron.model.realm_model import general_ict_model_provider
+from megatron.checkpointing import load_biencoder_checkpoint
+from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
+from megatron.data.orqa_wiki_dataset import get_open_retrieval_batch
+from megatron.data.biencoder_dataset_utils import get_one_epoch_dataloader
+from megatron.data.realm_index import detach, OpenRetreivalDataStore
+from megatron.model.biencoder_model import biencoder_model_provider
from megatron.training import get_model


class IndexBuilder(object):
-   """Object for taking one pass over a dataset and creating a BlockData of its embeddings"""
+   """
+   Object for taking one pass over a dataset and creating a BlockData of its
+   embeddings
+   """

    def __init__(self):
        args = get_args()
        self.model = None
        self.dataloader = None
-       self.block_data = None
+       self.evidence_embedder_obj = None
+       self.biencoder_shared_query_context_model = \
+           args.biencoder_shared_query_context_model

-       # need to know whether we're using a REALM checkpoint (args.load) or ICT checkpoint
+       # need to know whether we're using a REALM checkpoint (args.load)
+       # or ICT checkpoint
        assert not (args.load and args.ict_load)
-       self.using_realm_chkpt = args.ict_load is None
+       # self.using_realm_chkpt = args.ict_load is None

        self.log_interval = args.indexer_log_interval
        self.batch_size = args.indexer_batch_size
...
...
@@ -33,59 +40,88 @@ class IndexBuilder(object):
        self.iteration = self.total_processed = 0

    def load_attributes(self):
-       """Load the necessary attributes: model, dataloader and empty BlockData"""
-       model = get_model(lambda: general_ict_model_provider(only_block_model=True))
-       self.model = load_ict_checkpoint(model, only_block_model=True,
-                                        from_realm_chkpt=self.using_realm_chkpt)
-       self.model.eval()
-       self.dataset = get_ict_dataset()
-       self.dataloader = iter(get_one_epoch_dataloader(self.dataset, self.batch_size))
-       self.block_data = BlockData(load_from_path=False)
+       """
+       Load the necessary attributes: model, dataloader and empty BlockData
+       """
+       only_context_model = True
+       if self.biencoder_shared_query_context_model:
+           only_context_model = False
+
+       model = get_model(lambda: biencoder_model_provider(only_context_model \
+           = only_context_model, biencoder_shared_query_context_model = \
+           self.biencoder_shared_query_context_model))
+
+       self.model = load_biencoder_checkpoint(model,
+           only_context_model=only_context_model)
+
+       assert len(self.model) == 1
+       self.model[0].eval()
+
+       self.dataset = get_open_retrieval_wiki_dataset()
+       self.dataloader = iter(get_one_epoch_dataloader(self.dataset, \
+           self.batch_size))
+
+       self.evidence_embedder_obj = OpenRetreivalDataStore( \
+           load_from_path=False)

    def track_and_report_progress(self, batch_size):
-       """Utility function for tracking progress"""
+       """
+       Utility function for tracking progress
+       """
        self.iteration += 1
        self.total_processed += batch_size * self.num_total_builders
        if self.is_main_builder and self.iteration % self.log_interval == 0:
            print('Batch {:10d} | Total {:10d}'.format(self.iteration,
                  self.total_processed), flush=True)

    def build_and_save_index(self):
-       """Goes through one epoch of the dataloader and adds all data to this instance's BlockData.
-
-       The copy of BlockData is saved as a shard, which when run in a distributed setting will be
-       consolidated by the rank 0 process and saved as a final pickled BlockData.
+       """
+       Goes through one epoch of the dataloader and adds all data to this
+       instance's BlockData.
+
+       The copy of BlockData is saved as a shard, which when run in a
+       distributed setting will be consolidated by the rank 0 process
+       and saved as a final pickled BlockData.
        """
+       assert len(self.model) == 1
+       unwrapped_model = self.model[0]
+       while not hasattr(unwrapped_model, 'embed_text'):
+           unwrapped_model = unwrapped_model.module

        while True:
            try:
                # batch also has query_tokens and query_pad_data
-               _, _, block_tokens, block_pad_mask, block_sample_data = \
-                   get_ict_batch(self.dataloader)
+               row_id, context_tokens, context_mask, context_types, \
+                   context_pad_mask = get_open_retrieval_batch( \
+                   self.dataloader)
            except (StopIteration, IndexError):
                break

-           unwrapped_model = self.model
-           while not hasattr(unwrapped_model, 'embed_block'):
-               unwrapped_model = unwrapped_model.module
-
-           # detach, separate fields and add to BlockData
-           block_logits = detach(unwrapped_model.embed_block(block_tokens, block_pad_mask))
-           detached_data = detach(block_sample_data)
-
-           # block_sample_data is a 2D array [batch x 4]
-           # with columns [start_idx, end_idx, doc_idx, block_idx] same as class BlockSampleData
-           block_indices = detached_data[:, 3]
-           block_metas = detached_data[:, :3]
-
-           self.block_data.add_block_data(block_indices, block_logits, block_metas)
-           self.track_and_report_progress(batch_size=block_tokens.shape[0])
-
-       # This process signals to finalize its shard and then synchronize with the other processes
-       self.block_data.save_shard()
+           # TODO: can we add with torch.no_grad() to reduce memory usage
+           # detach, separate fields and add to BlockData
+           assert context_mask.dtype == torch.bool
+           context_logits = unwrapped_model.embed_text(
+               unwrapped_model.context_model, context_tokens, context_mask,
+               context_types)
+           context_logits = detach(context_logits)
+           row_id = detach(row_id)
+
+           self.evidence_embedder_obj.add_block_data(row_id, context_logits)
+           self.track_and_report_progress(batch_size=len(row_id))
+
+       # This process signals to finalize its shard and then synchronize with
+       # the other processes
+       self.evidence_embedder_obj.save_shard()
        torch.distributed.barrier()
        del self.model

        # rank 0 process builds the final copy
        if self.is_main_builder:
-           self.block_data.merge_shards_and_save()
+           self.evidence_embedder_obj.merge_shards_and_save()
            # make sure that every single piece of data was embedded
-           assert len(self.block_data.embed_data) == len(self.dataset)
-           self.block_data.clear()
+           assert len(self.evidence_embedder_obj.embed_data) == \
+               len(self.dataset)
+           self.evidence_embedder_obj.clear()

        # complete building the final copy
        torch.distributed.barrier()
megatron/initialize.py
...
...
@@ -17,16 +17,20 @@
import random
import os
import time

import numpy as np
import torch

from megatron import fused_kernels
from megatron import get_adlr_autoresume
from megatron import get_args
from megatron import get_tensorboard_writer
from megatron import mpu
from megatron.global_vars import set_global_variables
-from megatron.mpu import set_tensor_model_parallel_rank, set_tensor_model_parallel_world_size
+from megatron.mpu import (set_tensor_model_parallel_rank,
+                          set_tensor_model_parallel_world_size)


def initialize_megatron(extra_args_provider=None, args_defaults={},
                        ignore_unknown_args=False, allow_no_cuda=False):
...
...
@@ -37,8 +41,7 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
    what you are doing.
    Returns a function to finalize distributed env initialization
    (optionally, only when args.lazy_mpu_init == True)
    """
    if not allow_no_cuda:
        # Make sure cuda is available.
        assert torch.cuda.is_available(), 'Megatron requires CUDA.'
...
...
@@ -66,7 +69,8 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
        # delayed initialization of DDP-related stuff
        # We only set basic DDP globals
        set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
-       # and return function for external DDP manager to call when it has DDP initialized
+       # and return function for external DDP manager
+       # to call when it has DDP initialized
        set_tensor_model_parallel_rank(args.rank)
        return finish_mpu_init
    else:
...
...
@@ -79,19 +83,71 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
    # Autoresume.
    _init_autoresume()

-   # Compile dataset C++ code.
-   try:
-       from megatron.data import helpers
-   except:
-       if torch.distributed.get_rank() == 0:
-           from megatron.data.dataset_utils import compile_helper
-           compile_helper()
-       # Simple barrier
-       torch.distributed.barrier()
+   # Compile dependencies.
+   _compile_dependencies()

    # No continuation function
    return None


+def _compile_dependencies():
+
+   args = get_args()
+
+   # =========================
+   # Compile dataset C++ code.
+   # =========================
+   # TODO: move this to ninja
+   if torch.distributed.get_rank() == 0:
+       start_time = time.time()
+       print('> compiling dataset index builder ...')
+       from megatron.data.dataset_utils import compile_helper
+       compile_helper()
+       print('>>> done with dataset index builder. Compilation time: {:.3f} '
+             'seconds'.format(time.time() - start_time), flush=True)
+
+   # ==================
+   # Load fused kernels
+   # ==================
+
+   # Custom kernel constraints check.
+   seq_len = args.seq_length
+   attn_batch_size = \
+       (args.num_attention_heads / args.tensor_model_parallel_size) * \
+       args.micro_batch_size
+   # Constraints on sequence length and attn_batch_size to enable warp based
+   # optimization and upper triangular optimization (for causal mask)
+   custom_kernel_constraint = seq_len > 16 and seq_len <= 2048 and \
+       seq_len % 4 == 0 and attn_batch_size % 4 == 0
+   # Print a warning.
+   if not ((args.fp16 or args.bf16) and
+           custom_kernel_constraint and
+           args.masked_softmax_fusion):
+       if args.rank == 0:
+           print('WARNING: constraints for invoking optimized'
+                 ' fused softmax kernel are not met. We default'
+                 ' back to unfused kernel invocations.', flush=True)
+
+   # Always build on rank zero first.
+   if torch.distributed.get_rank() == 0:
+       start_time = time.time()
+       print('> compiling and loading fused kernels ...', flush=True)
+       fused_kernels.load(args)
+       torch.distributed.barrier()
+   else:
+       torch.distributed.barrier()
+       fused_kernels.load(args)
+   # Simple barrier to make sure all ranks have passed the
+   # compilation phase successfully before moving on to the
+   # rest of the program. We think this might ensure that
+   # the lock is released.
+   torch.distributed.barrier()
+   if torch.distributed.get_rank() == 0:
+       print('>>> done with compiling and loading fused kernels. '
+             'Compilation time: {:.3f} seconds'.format(
+                 time.time() - start_time), flush=True)


def _initialize_distributed():
    """Initialize torch.distributed and mpu."""
...
...
@@ -136,7 +192,8 @@ def _initialize_distributed():
        print('model parallel is already initialized')
    else:
        mpu.initialize_model_parallel(args.tensor_model_parallel_size,
-                                     args.pipeline_model_parallel_size)
+                                     args.pipeline_model_parallel_size,
+                                     args.virtual_pipeline_model_parallel_size)


def _init_autoresume():
...
...
megatron/model/__init__.py
...
...
@@ -13,34 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-_LAYER_NORM = None
+from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm

-def import_layernorm(fp32_residual_connection):
-    global _LAYER_NORM
-    if not _LAYER_NORM:
-        if fp32_residual_connection:
-            from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
-        else:
-            from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
-        _LAYER_NORM = LayerNorm
-    return _LAYER_NORM
-
-from .distributed import *
-from .bert_model import (BertModel, BertModelFirstStage,
-                         BertModelIntermediateStage, BertModelLastStage)
-from .realm_model import ICTBertModel
-from .gpt_model import (GPTModel, GPTModelFirstStage,
-                        GPTModelIntermediateStage, GPTModelLastStage)
+from .distributed import DistributedDataParallel
+from .bert_model import BertModel
+from .gpt_model import GPTModel
 from .language_model import get_language_model
-from .module import FP16Module
+from .realm_model import ICTBertModel
+from .module import Float16Module
megatron/model/bert_model.py
...
...
@@ -19,19 +19,16 @@ import torch
from megatron import get_args
from megatron import mpu
from megatron.model.enums import AttnMaskType
from megatron.model.language_model import parallel_lm_logits
from megatron.model.language_model import get_language_model
-from megatron.model import import_layernorm
+from megatron.model import LayerNorm
from megatron.model.utils import openai_gelu, erf_gelu
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from .module import MegatronModule


def bert_attention_mask_func(attention_scores, attention_mask):
    attention_scores.masked_fill_(attention_mask, -10000.0)
    return attention_scores


def bert_extended_attention_mask(attention_mask):
    # We create a 3D attention mask from a 2D tensor mask.
    # [b, 1, s]
...
...
@@ -77,13 +74,10 @@ class BertLMHead(MegatronModule):
        args = get_args()

        self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
-       self.bias.tensor_model_parallel = True
-       self.bias.partition_dim = 0
-       self.bias.stride = 1
+       mpu.set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
        self.parallel_output = parallel_output

        self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
-       LayerNorm = import_layernorm(args.fp32_residual_connection)
        self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
        self.gelu = torch.nn.functional.gelu
        if args.openai_gelu:
...
...
@@ -127,31 +121,39 @@ def post_language_model_processing(lm_output, pooled_output,
    return lm_loss, binary_logits


-class BertModelBase(MegatronModule):
+class BertModel(MegatronModule):
    """Bert Language model."""

-   def __init__(self, num_tokentypes=2, add_binary_head=True,
-                parallel_output=True):
-       super(BertModelBase, self).__init__()
+   def __init__(self,
+                num_tokentypes=2,
+                add_binary_head=True,
+                parallel_output=True,
+                pre_process=True,
+                post_process=True):
+       super(BertModel, self).__init__()
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        self.add_binary_head = add_binary_head
        self.parallel_output = parallel_output
+       self.pre_process = pre_process
+       self.post_process = post_process

        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)

        self.language_model, self._language_model_key = get_language_model(
            attention_mask_func=bert_attention_mask_func,
            num_tokentypes=num_tokentypes,
            add_pooler=self.add_binary_head,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
-           scaled_init_method=scaled_init_method)
+           scaled_init_method=scaled_init_method,
+           pre_process=self.pre_process,
+           post_process=self.post_process)

        self.initialize_word_embeddings(init_method_normal)
-       if mpu.is_pipeline_last_stage():
+       if self.post_process:
            self.lm_head = BertLMHead(
                self.word_embeddings_weight().size(0),
                args.hidden_size, init_method, args.layernorm_epsilon,
                parallel_output)
...
...
@@ -162,26 +164,30 @@ class BertModelBase(MegatronModule):
                init_method)
            self._binary_head_key = 'binary_head'

+   def set_input_tensor(self, input_tensor):
+       """See megatron.model.transformer.set_input_tensor()"""
+       self.language_model.set_input_tensor(input_tensor)
+
    def forward(self, bert_model_input, attention_mask,
                tokentype_ids=None, lm_labels=None):

        extended_attention_mask = bert_extended_attention_mask(attention_mask)
+       input_ids = bert_model_input
+       position_ids = bert_position_ids(input_ids)

-       kwargs = {}
-       if mpu.is_pipeline_first_stage():
-           input_ids = bert_model_input
-           position_ids = bert_position_ids(input_ids)
-           args = [input_ids, position_ids, extended_attention_mask]
-           kwargs['tokentype_ids'] = tokentype_ids
-       else:
-           args = [bert_model_input, extended_attention_mask]
-       lm_output = self.language_model(*args, **kwargs)
-       if mpu.is_pipeline_last_stage() and self.add_binary_head:
+       lm_output = self.language_model(
+           input_ids,
+           position_ids,
+           extended_attention_mask,
+           tokentype_ids=tokentype_ids)
+
+       if self.post_process and self.add_binary_head:
            lm_output, pooled_output = lm_output
        else:
            pooled_output = None

-       if mpu.is_pipeline_last_stage():
+       if self.post_process:
            return post_language_model_processing(lm_output, pooled_output,
                                                  self.lm_head, self.binary_head,
                                                  lm_labels,
...
...
@@ -200,15 +206,15 @@ class BertModelBase(MegatronModule):
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
-       if mpu.is_pipeline_last_stage():
+       if self.post_process:
            state_dict_[self._lm_head_key] \
                = self.lm_head.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)
-       if mpu.is_pipeline_last_stage() and self.add_binary_head:
+       if self.post_process and self.add_binary_head:
            state_dict_[self._binary_head_key] \
                = self.binary_head.state_dict(destination, prefix, keep_vars)
        # Save word_embeddings.
-       if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
+       if self.post_process and not self.pre_process:
            state_dict_[self._word_embeddings_for_head_key] \
                = self.word_embeddings.state_dict(destination, prefix, keep_vars)
        return state_dict_
...
...
@@ -218,74 +224,13 @@ class BertModelBase(MegatronModule):
        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
-       if mpu.is_pipeline_last_stage():
+       if self.post_process:
            self.lm_head.load_state_dict(
                state_dict[self._lm_head_key], strict=strict)
-       if mpu.is_pipeline_last_stage() and self.add_binary_head:
+       if self.post_process and self.add_binary_head:
            self.binary_head.load_state_dict(
                state_dict[self._binary_head_key], strict=strict)
        # Load word_embeddings.
-       if mpu.is_pipeline_last_stage() and not mpu.is_pipeline_first_stage():
+       if self.post_process and not self.pre_process:
            self.word_embeddings.load_state_dict(
                state_dict[self._word_embeddings_for_head_key], strict=strict)
-
-
-class BertModel(BertModelBase):
-
-    def __init__(self, num_tokentypes=2, add_binary_head=True,
-                 parallel_output=True):
-        super(BertModel, self).__init__(
-            num_tokentypes=num_tokentypes,
-            add_binary_head=add_binary_head,
-            parallel_output=parallel_output)
-
-    def forward(self, input_ids, attention_mask,
-                tokentype_ids=None, lm_labels=None):
-        return super(BertModel, self).forward(
-            input_ids,
-            attention_mask,
-            tokentype_ids=tokentype_ids,
-            lm_labels=lm_labels)
-
-
-class BertModelFirstStage(BertModelBase):
-
-    def __init__(self, num_tokentypes=2):
-        super(BertModelFirstStage, self).__init__(
-            num_tokentypes=num_tokentypes)
-
-    def forward(self, input_ids, attention_mask,
-                tokentype_ids=None):
-        return super(BertModelFirstStage, self).forward(
-            input_ids,
-            attention_mask,
-            tokentype_ids=tokentype_ids)
-
-
-class BertModelIntermediateStage(BertModelBase):
-
-    def __init__(self, num_tokentypes=2):
-        super(BertModelIntermediateStage, self).__init__(
-            num_tokentypes=num_tokentypes)
-
-    def forward(self, hidden_state, attention_mask):
-        return super(BertModelIntermediateStage, self).forward(
-            hidden_state, attention_mask)
-
-
-class BertModelLastStage(BertModelBase):
-
-    def __init__(self, num_tokentypes=2, add_binary_head=True,
-                 parallel_output=True):
-        super(BertModelLastStage, self).__init__(
-            num_tokentypes=num_tokentypes,
-            add_binary_head=add_binary_head,
-            parallel_output=parallel_output)
-
-    def forward(self, hidden_state, attention_mask,
-                lm_labels=None):
-        return super(BertModelLastStage, self).forward(
-            hidden_state, attention_mask,
-            lm_labels=lm_labels)
megatron/model/biencoder_model.py
0 → 100644
(diff collapsed; contents of the new file are not shown in this view)