OpenDAS / Megatron-LM · Commit 7a77abd9

Authored Jan 12, 2022 by Vijay Korthikanti
Parent: 9a8b89ac

    Phase1 merge: vit optimizations + dataset enhancements + scaled_softmax kernel

Showing 19 changed files with 1012 additions and 132 deletions.
Changed files:

    megatron/arguments.py                                  +11   -2
    megatron/data/data_samplers.py                         +55   -13
    megatron/data/image_folder.py                          +271  -0
    megatron/data/vit_dataset.py                           +52   -31
    megatron/fused_kernels/__init__.py                     +6    -0
    megatron/fused_kernels/scaled_masked_softmax.h         +214  -2
    megatron/fused_kernels/scaled_masked_softmax_cuda.cu   +1    -1
    megatron/fused_kernels/scaled_softmax.cpp              +75   -0
    megatron/fused_kernels/scaled_softmax_cuda.cu          +104  -0
    megatron/initialize.py                                 +3    -3
    megatron/model/distributed.py                          +7    -0
    megatron/model/fused_softmax.py                        +37   -4
    megatron/model/vision/classification.py                +65   -0
    megatron/model/vision/vit_backbone.py                  +88   -67
    megatron/mpu/__init__.py                               +1    -0
    megatron/mpu/initialize.py                             +9    -0
    megatron/optimizer/__init__.py                         +3    -3
    megatron/training.py                                   +2    -1
    pretrain_vit.py                                        +8    -5
megatron/arguments.py

@@ -835,11 +835,20 @@ def _add_vit_args(parser):
     group.add_argument('--num-classes', type=int, default=1000,
                        help='num of classes in vision classificaiton task')
-    group.add_argument('--img-dim', type=int, default=224,
-                       help='Image size for vision classification task')
+    group.add_argument('--img-h', type=int, default=224,
+                       help='Image height for vision classification task')
+    group.add_argument('--img-w', type=int, default=224,
+                       help='Image height for vision classification task')
     group.add_argument('--num-channels', type=int, default=3,
                        help='Number of channels in input image data')
     group.add_argument('--patch-dim', type=int, default=16,
                        help='patch dimension used in vit')
+    group.add_argument('--classes-fraction', type=float, default=1.0,
+                       help='training with fraction of classes.')
+    group.add_argument('--data-per-class-fraction', type=float, default=1.0,
+                       help='training with fraction of data per class.')
+    group.add_argument('--no-data-sharding', action='store_false',
+                       help='Disable data sharding.',
+                       dest='data_sharding')
     return parser
megatron/data/data_samplers.py

@@ -16,8 +16,10 @@
 """Dataloaders."""

-import torch
 import random
+import torch
+import numpy as np
+from torch.utils.data import Dataset
 from megatron import get_args
 from megatron import mpu

@@ -39,11 +41,13 @@ def build_pretraining_data_loader(dataset, consumed_samples):
             data_parallel_size=mpu.get_data_parallel_world_size())
     elif args.dataloader_type == 'cyclic':
         batch_sampler = MegatronPretrainingRandomSampler(
+            dataset,
             total_samples=len(dataset),
             consumed_samples=consumed_samples,
             micro_batch_size=args.micro_batch_size,
             data_parallel_rank=mpu.get_data_parallel_rank(),
-            data_parallel_size=mpu.get_data_parallel_world_size())
+            data_parallel_size=mpu.get_data_parallel_world_size(),
+            data_sharding=args.data_sharding)
     else:
         raise Exception('{} dataloader type is not supported.'.format(
                 args.dataloader_type))

@@ -103,16 +107,40 @@ class MegatronPretrainingSampler:
             yield batch[start_idx:end_idx]


+class RandomSeedDataset(Dataset):
+
+    def __init__(self, dataset):
+        args = get_args()
+        self.base_seed = args.seed
+        self.curr_seed = args.seed
+        self.dataset = dataset
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def set_epoch(self, epoch):
+        self.curr_seed = self.base_seed + epoch
+
+    def __getitem__(self, idx):
+        seed = idx + self.curr_seed
+        torch.manual_seed(seed)
+        random.seed(seed)
+        np.random.seed(seed)
+        return self.dataset[idx]
+
+
 class MegatronPretrainingRandomSampler:

-    def __init__(self, total_samples, consumed_samples, micro_batch_size,
-                 data_parallel_rank, data_parallel_size):
+    def __init__(self, dataset, total_samples, consumed_samples, micro_batch_size,
+                 data_parallel_rank, data_parallel_size, data_sharding):
         # Keep a copy of input params for later use.
+        self.dataset = dataset
         self.total_samples = total_samples
         self.consumed_samples = consumed_samples
         self.micro_batch_size = micro_batch_size
         self.data_parallel_rank = data_parallel_rank
         self.data_parallel_size = data_parallel_size
+        self.data_sharding = data_sharding
         self.micro_batch_times_data_parallel_size = \
             self.micro_batch_size * data_parallel_size
         self.last_batch_size = \

@@ -136,7 +164,11 @@ class MegatronPretrainingRandomSampler:
         current_epoch_samples = self.consumed_samples % active_total_samples
         assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0

+        if isinstance(self.dataset, RandomSeedDataset):
+            self.dataset.set_epoch(self.epoch)
+
         # data sharding and random sampling
-        bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
-                       * self.micro_batch_size
-        bucket_offset = current_epoch_samples // self.data_parallel_size
+        if self.data_sharding:
+            bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
+                           * self.micro_batch_size
+            bucket_offset = current_epoch_samples // self.data_parallel_size

@@ -146,6 +178,16 @@ class MegatronPretrainingRandomSampler:
-        g.manual_seed(self.epoch)
-        random_idx = torch.randperm(bucket_size, generator=g).tolist()
-        idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
+            g.manual_seed(self.epoch)
+            random_idx = torch.randperm(bucket_size, generator=g).tolist()
+            idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
+        else:
+            full_bucket_size = (self.total_samples // self.micro_batch_size) \
+                                * self.micro_batch_size
+            full_bucket_offset = current_epoch_samples
+            g = torch.Generator()
+            g.manual_seed(self.epoch)
+            idx_range_total = \
+                torch.randperm(full_bucket_size, generator=g).tolist()
+            idx_range_active = idx_range_total[full_bucket_offset:]
+            idx_range = idx_range_active[self.data_parallel_rank::self.data_parallel_size]

         batch = []
         # Last batch if not complete will be dropped.
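The two sampling modes above differ only in how per-rank indices are drawn. A standalone sketch (not part of the commit; the sizes and the per-rank start_idx = rank * bucket_size convention are illustrative assumptions based on the surrounding sampler code):

import torch

total_samples, micro_batch_size = 20, 2
data_parallel_size, data_parallel_rank = 2, 0
epoch = 0
mbxdp = micro_batch_size * data_parallel_size

# data_sharding=True: each rank shuffles its own contiguous bucket
g = torch.Generator()
g.manual_seed(epoch)
bucket_size = (total_samples // mbxdp) * micro_batch_size
start_idx = data_parallel_rank * bucket_size        # assumed per-rank offset
sharded = [start_idx + x
           for x in torch.randperm(bucket_size, generator=g).tolist()]

# data_sharding=False: one global shuffle, strided slice per rank
g.manual_seed(epoch)
full_bucket_size = (total_samples // micro_batch_size) * micro_batch_size
unsharded = torch.randperm(full_bucket_size, generator=g).tolist()
unsharded = unsharded[data_parallel_rank::data_parallel_size]

print(sharded[:4], unsharded[:4])

With sharding each rank touches a disjoint contiguous slab of the dataset (better page-cache locality); without it every rank sees samples drawn from the whole shuffled range.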
megatron/data/image_folder.py (new file, mode 100644)

# code taken from pytorch
# added support for classes_fraction and data_per_class_fraction

from torchvision.datasets import VisionDataset
from PIL import Image

import os
import os.path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
import numpy as np


def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:
    """Checks if a file is an allowed extension.
    Args:
        filename (string): path to a file
        extensions (tuple of strings): extensions to consider (lowercase)
    Returns:
        bool: True if the filename ends with one of given extensions
    """
    return filename.lower().endswith(extensions)


def is_image_file(filename: str) -> bool:
    """Checks if a file is an allowed image extension.
    Args:
        filename (string): path to a file
    Returns:
        bool: True if the filename ends with a known image extension
    """
    return has_file_allowed_extension(filename, IMG_EXTENSIONS)


def make_dataset(
    directory: str,
    class_to_idx: Dict[str, int],
    data_per_class_fraction: float,
    extensions: Optional[Tuple[str, ...]] = None,
    is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]:
    """Generates a list of samples of a form (path_to_sample, class).
    Args:
        directory (str): root dataset directory
        class_to_idx (Dict[str, int]): dictionary mapping class name to class index
        extensions (optional): A list of allowed extensions.
            Either extensions or is_valid_file should be passed. Defaults to None.
        is_valid_file (optional): A function that takes path of a file
            and checks if the file is a valid file
            (used to check of corrupt files) both extensions and
            is_valid_file should not be passed. Defaults to None.
    Raises:
        ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None.
    Returns:
        List[Tuple[str, int]]: samples of a form (path_to_sample, class)
    """
    instances = []
    directory = os.path.expanduser(directory)
    both_none = extensions is None and is_valid_file is None
    both_something = extensions is not None and is_valid_file is not None
    if both_none or both_something:
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
    if extensions is not None:
        def is_valid_file(x: str) -> bool:
            return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))
    is_valid_file = cast(Callable[[str], bool], is_valid_file)
    for target_class in sorted(class_to_idx.keys()):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            continue
        local_instances = []
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                if is_valid_file(path):
                    item = path, class_index
                    local_instances.append(item)

        instances.extend(local_instances[0:int(len(local_instances) * data_per_class_fraction)])

    return instances


class DatasetFolder(VisionDataset):
    """A generic data loader where the samples are arranged in this way: ::
        root/class_x/xxx.ext
        root/class_x/xxy.ext
        root/class_x/[...]/xxz.ext
        root/class_y/123.ext
        root/class_y/nsdf3.ext
        root/class_y/[...]/asd932_.ext
    Args:
        root (string): Root directory path.
        loader (callable): A function to load a sample given its path.
        extensions (tuple[string]): A list of allowed extensions.
            both extensions and is_valid_file should not be passed.
        transform (callable, optional): A function/transform that takes in
            a sample and returns a transformed version.
            E.g, ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        is_valid_file (callable, optional): A function that takes path of a file
            and check if the file is a valid file (used to check of corrupt files)
            both extensions and is_valid_file should not be passed.
    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        samples (list): List of (sample path, class_index) tuples
        targets (list): The class_index value for each image in the dataset
    """

    def __init__(
            self,
            root: str,
            loader: Callable[[str], Any],
            extensions: Optional[Tuple[str, ...]] = None,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            classes_fraction=1.0,
            data_per_class_fraction=1.0,
            is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> None:
        super(DatasetFolder, self).__init__(root, transform=transform,
                                            target_transform=target_transform)
        self.classes_fraction = classes_fraction
        self.data_per_class_fraction = data_per_class_fraction
        classes, class_to_idx = self._find_classes(self.root)
        samples = self.make_dataset(self.root,
                                    class_to_idx,
                                    self.data_per_class_fraction,
                                    extensions,
                                    is_valid_file)
        if len(samples) == 0:
            msg = "Found 0 files in subfolders of: {}\n".format(self.root)
            if extensions is not None:
                msg += "Supported extensions are: {}".format(",".join(extensions))
            raise RuntimeError(msg)

        self.loader = loader
        self.extensions = extensions
        self.total = len(samples)
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

    @staticmethod
    def make_dataset(
        directory: str,
        class_to_idx: Dict[str, int],
        data_per_class_fraction: float,
        extensions: Optional[Tuple[str, ...]] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> List[Tuple[str, int]]:
        return make_dataset(directory,
                            class_to_idx,
                            data_per_class_fraction,
                            extensions=extensions,
                            is_valid_file=is_valid_file)

    def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]:
        """
        Finds the class folders in a dataset.
        Args:
            dir (string): Root directory path.
        Returns:
            tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
        Ensures:
            No class is a subdirectory of another.
        """
        all_classes = [d.name for d in os.scandir(dir) if d.is_dir()]
        classes = all_classes[0:int(len(all_classes) * self.classes_fraction)]
        classes.sort()
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index
        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        curr_index = index
        for x in range(self.total):
            try:
                path, target = self.samples[curr_index]
                sample = self.loader(path)
                break
            except Exception as e:
                curr_index = np.random.randint(0, self.total)

        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self) -> int:
        return len(self.samples)


IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')


def pil_loader(path: str) -> Image.Image:
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')


# TODO: specify the return type
def accimage_loader(path: str) -> Any:
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)


def default_loader(path: str) -> Any:
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return accimage_loader(path)
    else:
        return pil_loader(path)


class ImageFolder(DatasetFolder):
    """A generic data loader where the images are arranged in this way: ::
        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/[...]/xxz.png
        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/[...]/asd932_.png
    Args:
        root (string): Root directory path.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        is_valid_file (callable, optional): A function that takes path of an Image file
            and check if the file is a valid file (used to check of corrupt files)
    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples
    """

    def __init__(
            self,
            root: str,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            classes_fraction=1.0,
            data_per_class_fraction=1.0,
            loader: Callable[[str], Any] = default_loader,
            is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        super(ImageFolder, self).__init__(root,
                                          loader,
                                          IMG_EXTENSIONS if is_valid_file is None else None,
                                          transform=transform,
                                          target_transform=target_transform,
                                          classes_fraction=classes_fraction,
                                          data_per_class_fraction=data_per_class_fraction,
                                          is_valid_file=is_valid_file)
        self.imgs = self.samples
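A hypothetical usage sketch of the two new knobs (the dataset path is an assumption): classes_fraction keeps a prefix of the class folders and data_per_class_fraction keeps a prefix of each kept class's files, which is convenient for quick smoke tests on a subset of ImageNet:

from megatron.data.image_folder import ImageFolder

dataset = ImageFolder(
    root='/data/imagenet/train',     # hypothetical path
    classes_fraction=0.5,            # keep the first 50% of class folders
    data_per_class_fraction=0.1,     # keep the first 10% of files per class
)
print(len(dataset.classes), len(dataset))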
megatron/data/vit_dataset.py

@@ -13,46 +13,67 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+import random
+import numpy as np
 import torch
-from torchvision import datasets, transforms
+import torchvision.transforms as T
+from torchvision import datasets
 from megatron import get_args
+from megatron.data.image_folder import ImageFolder
 from megatron.data.autoaugment import ImageNetPolicy
+from megatron.data.data_samplers import RandomSeedDataset
+
+
+class ClassificationTransform():
+    def __init__(self, image_size, train=True):
+        args = get_args()
+        assert args.fp16 or args.bf16
+        self.data_type = torch.half if args.fp16 else torch.bfloat16
+        if train:
+            self.transform = T.Compose([
+                T.RandomResizedCrop(image_size),
+                T.RandomHorizontalFlip(),
+                T.ColorJitter(0.4, 0.4, 0.4, 0.1),
+                ImageNetPolicy(),
+                T.ToTensor(),
+                T.Normalize(*self.mean_std),
+                T.ConvertImageDtype(self.data_type)
+            ])
+        else:
+            self.transform = T.Compose([
+                T.Resize(image_size),
+                T.CenterCrop(image_size),
+                T.ToTensor(),
+                T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+                T.ConvertImageDtype(self.data_type)
+            ])
+
+    def __call__(self, input):
+        output = self.transform(input)
+        return output
+

-def build_train_valid_datasets(data_path, crop_size=224, color_jitter=True):
+def build_train_valid_datasets(data_path, image_size=224):
+    args = get_args()
+    train_transform = ClassificationTransform(image_size)
+    val_transform = ClassificationTransform(image_size, train=False)

     # training dataset
-    train_data_path = os.path.join(data_path[0], "train")
-    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
-    process = [
-        transforms.RandomResizedCrop(crop_size),
-        transforms.RandomHorizontalFlip(),
-    ]
-    if color_jitter:
-        process += [
-            transforms.ColorJitter(
-                brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1
-            )
-        ]
-    fp16_t = transforms.ConvertImageDtype(torch.half)
-    process += [ImageNetPolicy(), transforms.ToTensor(), normalize, fp16_t]
-    transform_train = transforms.Compose(process)
-    train_data = datasets.ImageFolder(
-        root=train_data_path, transform=transform_train
-    )
+    train_data_path = data_path[0]
+    train_data = ImageFolder(
+        root=train_data_path,
+        transform=train_transform,
+        classes_fraction=args.classes_fraction,
+        data_per_class_fraction=args.data_per_class_fraction
+    )
+    train_data = RandomSeedDataset(train_data)

     # validation dataset
-    val_data_path = os.path.join(data_path[0], "val")
-    transform_val = transforms.Compose(
-        [
-            transforms.Resize(crop_size),
-            transforms.CenterCrop(crop_size),
-            transforms.ToTensor(),
-            normalize,
-            fp16_t,
-        ]
-    )
-    val_data = datasets.ImageFolder(
-        root=val_data_path, transform=transform_val
-    )
+    val_data_path = data_path[1]
+    val_data = ImageFolder(
+        root=val_data_path,
+        transform=val_transform
+    )
+    val_data = RandomSeedDataset(val_data)

     return train_data, val_data
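Callers now pass the train and val directories separately, since build_train_valid_datasets reads data_path[0] and data_path[1] instead of joining "train"/"val" onto one root itself. A sketch (hypothetical paths; assumes Megatron's global args are already initialized, since the transforms call get_args()):

from megatron.data.vit_dataset import build_train_valid_datasets

train_ds, val_ds = build_train_valid_datasets(
    data_path=['/data/imagenet/train', '/data/imagenet/val'],  # hypothetical
    image_size=224,
)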
megatron/fused_kernels/__init__.py

@@ -78,6 +78,12 @@ def load(args):
     scaled_masked_softmax_cuda = _cpp_extention_load_helper(
         "scaled_masked_softmax_cuda", sources, extra_cuda_flags)

+    # Softmax
+    sources=[srcpath / 'scaled_softmax.cpp',
+             srcpath / 'scaled_softmax_cuda.cu']
+    scaled_softmax_cuda = _cpp_extention_load_helper(
+        "scaled_softmax_cuda", sources, extra_cuda_flags)
+
     # =================================
     # Mixed precision fused layer norm.
     # =================================
megatron/fused_kernels/scaled_masked_softmax.h

@@ -90,6 +90,117 @@ __device__ __forceinline__ void warp_reduce(acc_t* sum) {
     }
 }

+/*
+ * Extended softmax (from native aten pytorch) with following additional features
+ * 1) input scaling
+ */
+template <typename input_t, typename output_t, typename acc_t, int log2_elements>
+__global__ void scaled_softmax_warp_forward(
+    output_t *dst,
+    const input_t *src,
+    const acc_t scale,
+    int micro_batch_size,
+    int element_count)
+{
+    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
+    // warp_size of method warp_softmax_forward_kernel.
+    constexpr int next_power_of_two = 1 << log2_elements;
+    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
+    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
+    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
+    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
+
+    // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
+    // gridDim/blockIdx = (seq_len, attn_heads, batches)
+    int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) + threadIdx.y) * WARP_BATCH;
+
+    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
+    // many batches have to computed within this WARP.
+    int local_batches = micro_batch_size - first_batch;
+    if (local_batches > WARP_BATCH)
+        local_batches = WARP_BATCH;
+
+    // there might be multiple batches per warp. compute the index within the batch
+    int local_idx = threadIdx.x;
+
+    src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
+    dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
+
+    // load data from global memory
+    acc_t elements[WARP_BATCH][WARP_ITERATIONS];
+    input_t temp_data[ELEMENTS_PER_LDG_STG];
+    #pragma unroll
+    for (int i = 0;  i < WARP_BATCH;  ++i) {
+        int batch_element_count = (i >= local_batches) ? 0 : element_count;
+
+        #pragma unroll
+        for (int it = 0;  it < WARP_ITERATIONS;  it += ELEMENTS_PER_LDG_STG) {
+            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
+
+            if (element_index < batch_element_count) {
+                int itr_idx = i*element_count+it*WARP_SIZE;
+                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
+
+                #pragma unroll
+                for (int element = 0;  element < ELEMENTS_PER_LDG_STG;  ++element) {
+                    elements[i][it + element] = (acc_t)temp_data[element] * scale;
+                }
+            } else {
+                #pragma unroll
+                for (int element = 0;  element < ELEMENTS_PER_LDG_STG;  ++element) {
+                    elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
+                }
+            }
+        }
+    }
+
+    // compute max_value
+    acc_t max_value[WARP_BATCH];
+    #pragma unroll
+    for (int i = 0;  i < WARP_BATCH;  ++i) {
+        max_value[i] = elements[i][0];
+        #pragma unroll
+        for (int it = 1;  it < WARP_ITERATIONS;  ++it) {
+            max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
+        }
+    }
+    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
+
+    acc_t sum[WARP_BATCH] { 0.0f };
+    #pragma unroll
+    for (int i = 0;  i < WARP_BATCH;  ++i) {
+        #pragma unroll
+        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
+            elements[i][it] = std::exp((elements[i][it] - max_value[i]));
+            sum[i] += elements[i][it];
+        }
+    }
+    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
+
+    // store result
+    output_t out[ELEMENTS_PER_LDG_STG];
+    #pragma unroll
+    for (int i = 0;  i < WARP_BATCH;  ++i) {
+        if (i >= local_batches)
+            break;
+        #pragma unroll
+        for (int it = 0;  it < WARP_ITERATIONS;  it += ELEMENTS_PER_LDG_STG) {
+            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
+            if (element_index < element_count) {
+                #pragma unroll
+                for (int element = 0;  element < ELEMENTS_PER_LDG_STG;  ++element) {
+                    out[element] = elements[i][it + element] / sum[i];
+                }
+                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
+            } else {
+                break;
+            }
+        }
+    }
+}
+
 /*
  * Extended softmax (from native aten pytorch) with following additional features
  * 1) input scaling
  ...

@@ -326,6 +437,98 @@ int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int att
     return batches_per_block;
 }

+template <typename input_t, typename output_t, typename acc_t>
+void dispatch_scaled_softmax_forward(
+    output_t *dst,
+    const input_t *src,
+    const input_t scale,
+    int query_seq_len,
+    int key_seq_len,
+    int batches,
+    int attn_heads)
+{
+    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096 );
+    if (key_seq_len == 0) {
+        return;
+    } else {
+        int log2_elements = log2_ceil(key_seq_len);
+        const int next_power_of_two = 1 << log2_elements;
+        int batch_count = batches * attn_heads * query_seq_len;
+
+        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
+        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
+
+        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
+        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
+
+        // use 128 threads per block to maximimize gpu utilization
+        constexpr int threads_per_block = 128;
+
+        int warps_per_block = (threads_per_block / warp_size);
+        int batches_per_block = warps_per_block * batches_per_warp;
+        TORCH_INTERNAL_ASSERT(query_seq_len % batches_per_block == 0);
+        dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches);
+        dim3 threads(warp_size, warps_per_block, 1);
+        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
+        switch (log2_elements) {
+            case 0: // 1
+                scaled_softmax_warp_forward<input_t, output_t, acc_t, 0>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
+                break;
+            case 1: // 2
+                scaled_softmax_warp_forward<input_t, output_t, acc_t, 1>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
+                break;
+            case 2: // 4
+                scaled_softmax_warp_forward<input_t, output_t, acc_t, 2>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
+                break;
+            case 3: // 8
+                scaled_softmax_warp_forward<input_t, output_t, acc_t, 3>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
+                break;
+            case 4: // 16
+                scaled_softmax_warp_forward<input_t, output_t, acc_t, 4>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
+                break;
+            case 5: // 32
+                scaled_softmax_warp_forward<input_t, output_t, acc_t, 5>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
+                break;
+            case 6: // 64
+                scaled_softmax_warp_forward<input_t, output_t, acc_t, 6>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
+                break;
+            case 7: // 128
+                scaled_softmax_warp_forward<input_t, output_t, acc_t, 7>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
+                break;
+            case 8: // 256
+                scaled_softmax_warp_forward<input_t, output_t, acc_t, 8>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
+                break;
+            case 9: // 512
+                scaled_softmax_warp_forward<input_t, output_t, acc_t, 9>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
+                break;
+            case 10: // 1024
+                scaled_softmax_warp_forward<input_t, output_t, acc_t, 10>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
+                break;
+            case 11: // 2048
+                scaled_softmax_warp_forward<input_t, output_t, acc_t, 11>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
+                break;
+            case 12: // 4096
+                scaled_softmax_warp_forward<input_t, output_t, acc_t, 12>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
+                break;
+            default:
+                break;
+        }
+    }
+}
+
 template <typename input_t, typename output_t, typename acc_t>
 void dispatch_scaled_masked_softmax_forward(
     output_t *dst,

@@ -338,7 +541,7 @@ void dispatch_scaled_masked_softmax_forward(
     int attn_heads,
     int pad_batches)
 {
-    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048 );
+    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096 );
     if (key_seq_len == 0) {
         return;
     } else {

@@ -410,6 +613,10 @@ void dispatch_scaled_masked_softmax_forward(
                 scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
                     <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                 break;
+            case 12: // 4096
+                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 12>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
+                break;
             default:
                 break;
         }

@@ -427,7 +634,7 @@ void dispatch_scaled_masked_softmax_backward(
     int batches,
     int attn_heads)
 {
-    TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 2048 );
+    TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 4096 );
     if (key_seq_len == 0) {
        return;
     } else {

@@ -498,6 +705,11 @@ void dispatch_scaled_masked_softmax_backward(
                 scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
                     <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                 break;
+            case 12: // 4096
+                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 12>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
+                break;
             default:
                 break;
         }
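For intuition, a plain-PyTorch reference (not part of the commit) of what scaled_softmax_warp_forward computes: a numerically stable softmax of scale * x over the last (key_seq_len) dimension of a [batches, attn_heads, query_seq_len, key_seq_len] tensor, accumulated in fp32 like the kernel's acc_t:

import torch

def scaled_softmax_reference(x: torch.Tensor, scale: float) -> torch.Tensor:
    # max-subtraction for numerical stability, matching the kernel's
    # warp-level max reduction before the exp/sum pass
    z = x.float() * scale
    z = z - z.max(dim=-1, keepdim=True).values
    e = z.exp()
    return (e / e.sum(dim=-1, keepdim=True)).to(x.dtype)

x = torch.randn(2, 4, 8, 32, dtype=torch.half)
out = scaled_softmax_reference(x, scale=0.125)
assert torch.allclose(out.float().sum(-1), torch.ones(2, 4, 8), atol=1e-3)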
megatron/fused_kernels/scaled_masked_softmax_cuda.cu

@@ -44,7 +44,7 @@ torch::Tensor fwd_cuda(
   const int attn_heads = input.size(1);
   const int query_seq_len = input.size(2);
   const int key_seq_len = input.size(3);
-  TORCH_INTERNAL_ASSERT(key_seq_len <= 2048);
+  TORCH_INTERNAL_ASSERT(key_seq_len <= 4096);
   TORCH_INTERNAL_ASSERT(query_seq_len > 1);
   TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
   TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
megatron/fused_kernels/scaled_softmax.cpp
0 → 100644
View file @
7a77abd9
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace
multihead_attn
{
namespace
fused_softmax
{
namespace
scaled_softmax
{
torch
::
Tensor
fwd_cuda
(
torch
::
Tensor
const
&
input
,
float
scale_factor
);
torch
::
Tensor
bwd_cuda
(
torch
::
Tensor
const
&
output_grads
,
torch
::
Tensor
const
&
softmax_results
,
float
scale_factor
);
torch
::
Tensor
fwd
(
torch
::
Tensor
const
&
input
,
float
scale_factor
)
{
AT_ASSERTM
(
input
.
dim
()
==
4
,
"expected 4D tensor"
);
AT_ASSERTM
((
input
.
scalar_type
()
==
at
::
ScalarType
::
Half
)
||
(
input
.
scalar_type
()
==
at
::
ScalarType
::
BFloat16
),
"Only fp16 and bf16 are supported"
);
return
fwd_cuda
(
input
,
scale_factor
);
}
torch
::
Tensor
bwd
(
torch
::
Tensor
const
&
output_grads
,
torch
::
Tensor
const
&
softmax_results
,
float
scale_factor
)
{
AT_ASSERTM
(
output_grads
.
dim
()
==
4
,
"expected 3D tensor"
);
AT_ASSERTM
(
softmax_results
.
dim
()
==
4
,
"expected 3D tensor"
);
AT_ASSERTM
((
output_grads
.
scalar_type
()
==
at
::
ScalarType
::
Half
)
||
(
output_grads
.
scalar_type
()
==
at
::
ScalarType
::
BFloat16
),
"Only fp16 and bf16 are supported"
);
AT_ASSERTM
((
softmax_results
.
scalar_type
()
==
at
::
ScalarType
::
Half
)
||
(
softmax_results
.
scalar_type
()
==
at
::
ScalarType
::
BFloat16
),
"Only fp16 and bf16 are supported"
);
return
bwd_cuda
(
output_grads
,
softmax_results
,
scale_factor
);
}
}
// end namespace scaled_softmax
}
// end namespace fused_softmax
}
// end namespace multihead_attn
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"forward"
,
&
multihead_attn
::
fused_softmax
::
scaled_softmax
::
fwd
,
"Self Multihead Attention scaled, softmax -- Forward."
);
m
.
def
(
"backward"
,
&
multihead_attn
::
fused_softmax
::
scaled_softmax
::
bwd
,
"Self Multihead Attention scaled, softmax -- Backward."
);
}
megatron/fused_kernels/scaled_softmax_cuda.cu (new file, mode 100644)

/* coding=utf-8
 * Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
 * (Apache-2.0 license header identical to scaled_softmax.cpp above)
 */

#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_masked_softmax.h"
#include "type_shim.h"

namespace multihead_attn {
namespace fused_softmax {
namespace scaled_softmax {

torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    float scale_factor)
{
  // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
  const int batches = input.size(0);
  const int attn_heads = input.size(1);
  const int query_seq_len = input.size(2);
  const int key_seq_len = input.size(3);
  TORCH_INTERNAL_ASSERT(key_seq_len <= 4096);
  TORCH_INTERNAL_ASSERT(query_seq_len > 1);

  // Output
  auto act_options = input.options().requires_grad(false);
  torch::Tensor softmax_results =
      torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);

  // Softmax Intermediate Result Ptr
  void* input_ptr = static_cast<void*>(input.data_ptr());
  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());

  DISPATCH_HALF_AND_BFLOAT(
      input.scalar_type(),
      "dispatch_scaled_softmax_forward",
      dispatch_scaled_softmax_forward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(softmax_results_ptr),
          reinterpret_cast<const scalar_t*>(input_ptr),
          scale_factor,
          query_seq_len,
          key_seq_len,
          batches,
          attn_heads);
      );
  return softmax_results;
}

torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads_,
    torch::Tensor const& softmax_results_,
    float scale_factor)  {

  auto output_grads = output_grads_.contiguous();
  auto softmax_results = softmax_results_.contiguous();

  //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
  const int batches = output_grads.size(0);
  const int attn_heads = output_grads.size(1);
  const int query_seq_len = output_grads.size(2);
  const int key_seq_len = output_grads.size(3);

  void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

  //Softmax Grad
  DISPATCH_HALF_AND_BFLOAT(
      output_grads_.scalar_type(),
      "dispatch_scaled_masked_softmax_backward",
      dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
          scale_factor,
          query_seq_len,
          key_seq_len,
          batches,
          attn_heads);
      );

  //backward pass is completely in-place
  return output_grads;
}
}
}
}
megatron/initialize.py

@@ -118,7 +118,7 @@ def _compile_dependencies():
                       args.micro_batch_size
     # Constraints on sequence length and attn_batch_size to enable warp based
     # optimization and upper triangular optimization (for causal mask)
-    custom_kernel_constraint = seq_len > 16 and seq_len <= 2048 and \
+    custom_kernel_constraint = seq_len > 16 and seq_len <= 4096 and \
         seq_len % 4 == 0 and attn_batch_size % 4 == 0
     # Print a warning.
     if not ((args.fp16 or args.bf16) and

@@ -206,8 +206,8 @@ def _init_autoresume():
 def _set_random_seed(seed_):
     """Set random seed for reproducability."""
     if seed_ is not None and seed_ > 0:
-        # Ensure that different pipeline MP stages get different seeds.
+        # Ensure that different pipeline MP stages and different data parallel ranks get different seeds.
-        seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank())
+        seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank()) + (10 * mpu.get_data_parallel_rank())
         random.seed(seed)
         np.random.seed(seed)
         torch.manual_seed(seed)
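A quick check of the new offset arithmetic (plain Python, values illustrative): with base seed 1234, pipeline rank 2 and data-parallel rank 3 land on 1234 + 100*2 + 10*3 = 1464, and the 10-per-rank spacing keeps data-parallel ranks 0-9 distinct within one pipeline stage.

# Sketch: per-rank seed offsets introduced by this change.
base_seed = 1234
for pp_rank in range(2):          # pipeline-model-parallel rank
    for dp_rank in range(3):      # data-parallel rank
        seed = base_seed + (100 * pp_rank) + (10 * dp_rank)
        print(pp_rank, dp_rank, seed)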
megatron/model/distributed.py

@@ -185,6 +185,13 @@ class DistributedDataParallel(DistributedDataParallelBase):
                 buffer_.zero()

+    def broadcast_params(self):
+        for param in self.module.parameters():
+            torch.distributed.broadcast(param.data,
+                                        src=mpu.get_data_parallel_src_rank(),
+                                        group=mpu.get_data_parallel_group())
+
     def allreduce_gradients(self):
         """Reduce gradients across data parallel ranks."""
         # If we have buffers, simply reduce the data in the buffer.
megatron/model/fused_softmax.py

@@ -81,6 +81,37 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
         return input_grads, None, None


+class ScaledSoftmax(torch.autograd.Function):
+    """
+    Fused operation which performs following two operations in sequence
+    1. Scale the tensor.
+    2. Perform softmax.
+    """
+
+    @staticmethod
+    def forward(ctx, inputs, scale):
+        import scaled_softmax_cuda
+
+        scale_t = torch.tensor([scale])
+
+        softmax_results = scaled_softmax_cuda.forward(
+            inputs, scale_t[0]
+        )
+        ctx.save_for_backward(softmax_results, scale_t)
+        return softmax_results
+
+    @staticmethod
+    def backward(ctx, output_grads):
+        import scaled_softmax_cuda
+
+        softmax_results, scale_t = ctx.saved_tensors
+
+        input_grads = scaled_softmax_cuda.backward(
+            output_grads, softmax_results, scale_t[0]
+        )
+        return input_grads, None, None
+
+
 class FusedScaleMaskSoftmax(nn.Module):
     """
     fused operation: scaling + mask + softmax

@@ -137,12 +168,11 @@ class FusedScaleMaskSoftmax(nn.Module):
         if (
             self.scaled_masked_softmax_fusion  # user want to fuse
             and self.input_in_float16  # input must be fp16
-            and mask is not None  # mask tensor must not be None
-            and 16 < sk <= 2048  # sk must be 16 ~ 2048
+            and 16 < sk <= 4096  # sk must be 16 ~ 2048
             and sq % 4 == 0  # sq must be divisor of 4
             and attn_batches % 4 == 0  # np * b must be divisor of 4
         ):
-            if 0 <= sk <= 2048:
+            if 0 <= sk <= 4096:
                 batch_per_block = self.get_batch_per_block(sq, sk, b, np)

                 if self.attn_mask_type == AttnMaskType.causal:

@@ -166,7 +196,10 @@ class FusedScaleMaskSoftmax(nn.Module):
             return probs.view(b, np, sq, sk)
         else:
             # input is 4D tensor (b, np, sq, sk)
-            return ScaledMaskedSoftmax.apply(input, mask, scale)
+            if mask is not None:
+                return ScaledMaskedSoftmax.apply(input, mask, scale)
+            else:
+                return ScaledSoftmax.apply(input, scale)

     def forward_torch_softmax(self, input, mask):
         if self.input_in_float16 and self.softmax_in_fp32:
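A sanity-check sketch for the new mask-free path (not part of the commit; requires a GPU and a build of the scaled_softmax_cuda extension, and the tolerance is illustrative):

import torch
from megatron.model.fused_softmax import ScaledSoftmax

x = torch.randn(2, 4, 32, 32, dtype=torch.half, device='cuda',
                requires_grad=True)
scale = 1.0 / 8.0

fused = ScaledSoftmax.apply(x, scale)
# plain-torch reference: softmax of the scaled logits in fp32
ref = torch.softmax(x.float() * scale, dim=-1).half()
assert torch.allclose(fused, ref, atol=1e-3)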
megatron/model/vision/classification.py
0 → 100644
View file @
7a77abd9
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vision Transformer(VIT) model."""
import
torch
from
megatron
import
get_args
from
megatron.model.utils
import
get_linear_layer
from
megatron.model.vision.vit_backbone
import
VitBackbone
,
VitMlpHead
from
megatron.model.vision.mit_backbone
import
mit_b3_avg
from
megatron.model.vision.utils
import
trunc_normal_
from
megatron.model.module
import
MegatronModule
class
VitClassificationModel
(
MegatronModule
):
"""Vision Transformer Model."""
def
__init__
(
self
,
num_classes
,
finetune
=
False
,
pre_process
=
True
,
post_process
=
True
):
super
(
VitClassificationModel
,
self
).
__init__
()
args
=
get_args
()
self
.
hidden_size
=
args
.
hidden_size
self
.
num_classes
=
num_classes
self
.
finetune
=
finetune
self
.
pre_process
=
pre_process
self
.
post_process
=
post_process
self
.
backbone
=
VitBackbone
(
pre_process
=
self
.
pre_process
,
post_process
=
self
.
post_process
,
single_token_output
=
True
)
if
self
.
post_process
:
if
not
self
.
finetune
:
self
.
head
=
VitMlpHead
(
self
.
hidden_size
,
self
.
num_classes
)
else
:
self
.
head
=
get_linear_layer
(
self
.
hidden_size
,
self
.
num_classes
,
torch
.
nn
.
init
.
zeros_
)
def
set_input_tensor
(
self
,
input_tensor
):
"""See megatron.model.transformer.set_input_tensor()"""
self
.
backbone
.
set_input_tensor
(
input_tensor
)
def
forward
(
self
,
input
):
hidden_states
=
self
.
backbone
(
input
)
if
self
.
post_process
:
hidden_states
=
self
.
head
(
hidden_states
)
return
hidden_states
megatron/model/vi
t_model
.py
→
megatron/model/vi
sion/vit_backbone
.py
View file @
7a77abd9
...
@@ -18,16 +18,19 @@
...
@@ -18,16 +18,19 @@
import
math
import
math
import
einops
import
einops
import
torch
import
torch
import
apex
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
megatron
import
get_args
from
megatron
import
get_args
from
megatron.model
import
LayerNorm
from
megatron.model.transformer
import
ParallelTransformer
from
megatron.model.transformer
import
ParallelTransformer
from
megatron.model.utils
import
(
from
megatron.model.utils
import
(
get_linear_layer
,
get_linear_layer
,
init_method_normal
,
init_method_normal
,
scaled_init_method_normal
,
scaled_init_method_normal
,
)
)
from
.module
import
MegatronModule
from
megatron.model
.module
import
MegatronModule
CLASS_TOKEN_LENGTH
=
8
class
VitMlpHead
(
MegatronModule
):
class
VitMlpHead
(
MegatronModule
):
"""Pooler layer.
"""Pooler layer.
...
@@ -44,19 +47,26 @@ class VitMlpHead(MegatronModule):
...
@@ -44,19 +47,26 @@ class VitMlpHead(MegatronModule):
def
__init__
(
self
,
hidden_size
,
num_classes
):
def
__init__
(
self
,
hidden_size
,
num_classes
):
super
(
VitMlpHead
,
self
).
__init__
()
super
(
VitMlpHead
,
self
).
__init__
()
self
.
dense_in
=
torch
.
nn
.
Linear
(
hidden_size
,
hidden_size
)
self
.
dense_in
=
torch
.
nn
.
Linear
(
hidden_size
,
hidden_size
)
self
.
relu
=
torch
.
nn
.
ReLU
()
self
.
dense_out
=
torch
.
nn
.
Linear
(
hidden_size
,
num_classes
)
self
.
dense_out
=
torch
.
nn
.
Linear
(
hidden_size
,
num_classes
)
torch
.
nn
.
init
.
constant_
(
self
.
dense_out
.
bias
,
-
10
)
torch
.
nn
.
init
.
constant_
(
self
.
dense_out
.
bias
,
-
10
)
def
forward
(
self
,
hidden_states
,
sequence_index
=
0
):
def
forward
(
self
,
hidden_states
):
# hidden_states: [b,
s
, h]
# hidden_states: [b,
1
, h]
# sequence_index: index of the token to pool.
# sequence_index: index of the token to pool.
hidden_state
=
hidden_states
[:,
sequence_index
,
:]
dense_in_result
=
self
.
dense_in
(
hidden_states
)
dense_in_result
=
self
.
dense_in
(
hidden_state
)
tanh_result
=
torch
.
tanh
(
dense_in_result
)
tanh_result
=
torch
.
tanh
(
dense_in_result
)
dense_out_result
=
self
.
dense_out
(
tanh_result
)
dense_out_result
=
self
.
dense_out
(
tanh_result
)
return
dense_out_result
return
dense_out_result
def
isPerfectSquare
(
x
):
if
(
x
>=
0
):
sr
=
math
.
sqrt
(
x
)
return
(
int
(
sr
)
*
int
(
sr
)
==
x
)
return
False
def
twod_interpolate_position_embeddings_hook
(
def
twod_interpolate_position_embeddings_hook
(
state_dict
,
state_dict
,
prefix
,
prefix
,
...
@@ -68,66 +78,77 @@ def twod_interpolate_position_embeddings_hook(
...
@@ -68,66 +78,77 @@ def twod_interpolate_position_embeddings_hook(
):
):
args
=
get_args
()
args
=
get_args
()
num_patches_per_dim
=
args
.
img_
dim
//
args
.
patch_dim
num_patches_per_dim
_h
=
args
.
img_
h
//
args
.
patch_dim
num_patches
=
num_patches_per_dim
**
2
num_patches
_per_dim_w
=
args
.
img_w
//
args
.
patch_dim
seq_length
=
num_patches
+
1
num_patches
=
num_patches
_per_dim_h
*
num_patches_per_dim_w
hidden_size
=
args
.
hidden_size
hidden_size
=
args
.
hidden_size
key
=
prefix
+
"weight"
key
=
prefix
+
"weight"
# import pdb
# pdb.set_trace()
assert
key
in
state_dict
assert
key
in
state_dict
if
key
in
state_dict
:
if
key
in
state_dict
:
input_param
=
state_dict
[
key
]
input_param
=
state_dict
[
key
]
assert
input_param
.
shape
[
1
]
==
hidden_size
input_seq_len
=
input_param
.
shape
[
0
]
if
input_param
.
shape
[
0
]
!=
seq_length
:
assert
(
isPerfectSquare
(
input_seq_len
)
or
isPerfectSquare
(
input_seq_len
-
CLASS_TOKEN_LENGTH
))
input_has_class_token
=
not
isPerfectSquare
(
input_seq_len
)
num_tok_input
=
input_seq_len
-
CLASS_TOKEN_LENGTH
if
input_has_class_token
else
input_seq_len
num_tok_output
=
num_patches
output_has_class_token
=
args
.
class_token_present
# update input_param and load it to state_dict[key]
# update input_param and load it to state_dict[key]
if
input_has_class_token
:
input_param_tok
=
input_param
[:
CLASS_TOKEN_LENGTH
,
:]
input_param_grid
=
input_param
[
CLASS_TOKEN_LENGTH
:,
:]
else
:
input_param_tok
=
torch
.
zeros
(
CLASS_TOKEN_LENGTH
,
hidden_size
)
input_param_grid
=
input_param
num_tok_input
=
input_param
.
shape
[
0
]
-
1
assert
input_param
.
shape
[
1
]
==
hidden_size
num_tok_new
=
seq_length
-
1
input_param_tok
,
input_param_grid
=
(
if
num_tok_input
!=
num_tok_output
:
input_param
[:
1
,
:],
input_param
[
1
:,
:],
)
gs_input
=
int
(
math
.
sqrt
(
num_tok_input
))
gs_input
=
int
(
math
.
sqrt
(
num_tok_input
))
gs_new
=
int
(
math
.
sqrt
(
num_tok_ne
w
)
)
gs_new
=
(
num_patches_per_dim_h
,
num_patches_per_dim_
w
)
input_param_grid
=
input_param_grid
.
transpose
(
0
,
1
).
contiguous
()
input_param_grid
=
input_param_grid
.
transpose
(
0
,
1
).
contiguous
()
input_param_grid
=
input_param_grid
.
reshape
(
input_param_grid
=
input_param_grid
.
reshape
(
(
1
,
-
1
,
gs_input
,
gs_input
)
(
1
,
-
1
,
gs_input
,
gs_input
)
)
)
input_param_grid
=
input_param_grid
.
float
()
input_param_grid
=
input_param_grid
.
float
()
scale_factor
=
gs_new
/
gs_input
scale_factor
=
(
gs_new
[
0
]
/
gs_input
,
gs_new
[
1
]
/
gs_input
)
input_param_grid
=
F
.
interpolate
(
input_param_grid
=
F
.
interpolate
(
input_param_grid
,
scale_factor
=
scale_factor
,
mode
=
"bilinear"
input_param_grid
,
scale_factor
=
scale_factor
,
mode
=
"bilinear"
)
)
input_param_grid
=
input_param_grid
.
half
()
input_param_grid
=
input_param_grid
.
half
()
input_param_grid
=
input_param_grid
.
reshape
((
-
1
,
gs_new
*
gs_new
))
input_param_grid
=
input_param_grid
.
reshape
((
-
1
,
num_tok_output
))
input_param_grid
=
input_param_grid
.
transpose
(
0
,
1
).
contiguous
()
input_param_grid
=
input_param_grid
.
transpose
(
0
,
1
).
contiguous
()
assert
input_param_grid
.
shape
[
1
]
==
hidden_size
assert
input_param_grid
.
shape
[
1
]
==
hidden_size
input_param
=
torch
.
cat
((
input_param_tok
,
input_param_grid
),
dim
=
0
)
input_param
=
input_param_grid
assert
(
assert
(
input_param
.
shape
[
0
]
==
seq_length
input_param
.
shape
[
0
]
==
num_tok_output
and
input_param
.
shape
[
1
]
==
hidden_size
and
input_param
.
shape
[
1
]
==
hidden_size
)
)
if
output_has_class_token
:
input_param
=
torch
.
cat
((
input_param_tok
,
input_param
),
dim
=
0
)
state_dict
[
key
]
=
input_param
state_dict
[
key
]
=
input_param
class
Vit
Model
(
MegatronModule
):
class
Vit
Backbone
(
MegatronModule
):
"""Vision Transformer Model."""
"""Vision Transformer Model."""
def
__init__
(
self
,
def
__init__
(
self
,
num_classes
,
finetune
=
False
,
pre_process
=
True
,
pre_process
=
True
,
post_process
=
True
):
post_process
=
True
,
super
(
VitModel
,
self
).
__init__
(
share_word_embeddings
=
False
)
class_token
=
True
,
single_token_output
=
False
):
super
(
VitBackbone
,
self
).
__init__
(
share_word_embeddings
=
False
)
args
=
get_args
()
args
=
get_args
()
self
.
fp16_lm_cross_entropy
=
args
.
fp16_lm_cross_entropy
self
.
fp16_lm_cross_entropy
=
args
.
fp16_lm_cross_entropy
...
@@ -142,24 +163,32 @@ class VitModel(MegatronModule):
         self.pre_process = pre_process
         self.post_process = post_process
+        self.class_token = class_token
         self.hidden_size = args.hidden_size
-        self.num_classes = num_classes
         self.patch_dim = args.patch_dim
-        self.img_dim = args.img_dim
-        self.finetune = finetune
+        self.img_h = args.img_h
+        self.img_w = args.img_w
+        self.micro_batch_size = args.micro_batch_size
+        self.single_token_output = single_token_output

-        assert self.img_dim % self.patch_dim == 0
-        self.num_patches_per_dim = self.img_dim // self.patch_dim
-        self.num_patches = self.num_patches_per_dim ** 2
-        self.seq_length = self.num_patches + 1
+        assert self.img_h % self.patch_dim == 0
+        assert self.img_w % self.patch_dim == 0
+        self.num_patches_per_dim_h = self.img_h // self.patch_dim
+        self.num_patches_per_dim_w = self.img_w // self.patch_dim
+        self.num_patches = self.num_patches_per_dim_h * self.num_patches_per_dim_w
+        self.seq_length = self.num_patches + (CLASS_TOKEN_LENGTH if self.class_token else 0)
         self.flatten_dim = self.patch_dim * self.patch_dim * args.num_channels
+        self.input_tensor = None
+        self.position_ids = None

         if self.pre_process:
             # cls_token
-            self.cls_token = torch.nn.Parameter(
-                torch.randn(1, 1, self.hidden_size)
-            )
-            torch.nn.init.zeros_(self.cls_token)
+            if self.class_token:
+                self.cls_token = torch.nn.Parameter(
+                    torch.randn(1, CLASS_TOKEN_LENGTH, self.hidden_size)
+                )
+                torch.nn.init.zeros_(self.cls_token)
+            self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda()

             # Linear encoder
             self.linear_encoder = torch.nn.Linear(
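The square img_dim assumption is gone: patch counts are now taken per axis, so rectangular inputs work. A quick worked example of the arithmetic (values illustrative; CLASS_TOKEN_LENGTH is a constant defined elsewhere in vit_backbone.py, treated as 1 here for illustration):

# Worked example of the new patch arithmetic.
img_h, img_w, patch_dim = 224, 320, 16
CLASS_TOKEN_LENGTH = 1  # assumed value for this example only

assert img_h % patch_dim == 0 and img_w % patch_dim == 0
num_patches = (img_h // patch_dim) * (img_w // patch_dim)   # 14 * 20 = 280
seq_length = num_patches + CLASS_TOKEN_LENGTH               # 281 with a class token
print(num_patches, seq_length)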
...
@@ -173,8 +202,8 @@ class VitModel(MegatronModule):
             init_method_normal(args.init_method_std)(
                 self.position_embeddings.weight
             )
-            self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda()

+            args.class_token_present = self.class_token
             self.position_embeddings._register_load_state_dict_pre_hook(
                 twod_interpolate_position_embeddings_hook
             )
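The hook is attached through PyTorch's (private) _register_load_state_dict_pre_hook, which hands the hook the incoming state_dict before any copying, so entries can be rewritten in place. A minimal demonstration of that mechanism (toy module, not the Megatron hook):

import torch

def noop_resize_hook(state_dict, prefix, local_metadata, strict,
                     missing_keys, unexpected_keys, error_msgs):
    # Mutate the incoming entry before torch copies it into the module;
    # the Megatron hook resizes the position table at this point.
    key = prefix + "weight"
    if key in state_dict:
        state_dict[key] = state_dict[key].clone()

emb = torch.nn.Embedding(10, 4)
emb._register_load_state_dict_pre_hook(noop_resize_hook)
emb.load_state_dict(torch.nn.Embedding(10, 4).state_dict())  # hook runs first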
...
@@ -186,16 +215,7 @@ class VitModel(MegatronModule):
             self.init_method,
             self.scaled_init_method,
             pre_process=self.pre_process,
-            post_process=self.post_process
+            post_process=self.post_process,
         )

-        if self.post_process:
-            # MLP head
-            if not self.finetune:
-                self.mlp_head = VitMlpHead(self.hidden_size, self.num_classes)
-            else:
-                self.class_head = get_linear_layer(
-                    self.hidden_size, num_classes, torch.nn.init.zeros_
-                )

     def set_input_tensor(self, input_tensor):
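With the MLP/linear head deleted here, the backbone ends at the transformer output and task heads live in separate modules such as the new megatron/model/vision/classification.py. A hedged sketch of that composition (the commit's actual VitClassificationModel may differ; the stub backbone is purely illustrative):

import torch

class ClassificationHeadSketch(torch.nn.Module):
    """Hedged sketch: wraps a backbone whose forward returns a single
    [batch, hidden] token (i.e. VitBackbone(single_token_output=True))."""
    def __init__(self, backbone, hidden_size, num_classes):
        super().__init__()
        self.backbone = backbone
        self.head = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, images):
        return self.head(self.backbone(images))  # [batch, num_classes] logits

# Stub standing in for the real backbone:
stub = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 16 * 16, 768))
model = ClassificationHeadSketch(stub, 768, 1000)
logits = model(torch.randn(2, 3, 16, 16))   # -> shape [2, 1000]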
...
@@ -214,21 +234,22 @@ class VitModel(MegatronModule):
             assert rearranged_input.dtype == torch.half
             encoder_output = self.linear_encoder(rearranged_input)

-            cls_tokens = self.cls_token.expand(encoder_output.shape[0], -1, -1)
-            concatenated_tokens = torch.cat((cls_tokens, encoder_output), dim=1)
+            concatenated_tokens = encoder_output
+            if self.class_token:
+                cls_tokens = self.cls_token.expand(encoder_output.shape[0], -1, -1)
+                concatenated_tokens = torch.cat((cls_tokens, encoder_output), dim=1)

             token_embeddings = concatenated_tokens + \
-                self.position_embeddings(self.position_ids)
+                self.position_embeddings(self.position_ids[:, :concatenated_tokens.shape[1]])
             hidden_states = self.embedding_dropout(token_embeddings)
         else:
             hidden_states = input

         hidden_states = self.transformer(hidden_states, None)

-        if self.post_process:
-            if not self.finetune:
-                hidden_states = self.mlp_head(hidden_states)
-            else:
-                hidden_states = self.class_head(hidden_states[:, 0, :])
+        if self.single_token_output:
+            hidden_states = hidden_states[:, 0, :]

         return hidden_states
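The class-token concatenation and the position-id slice are easy to check in isolation; a standalone illustration with made-up shapes (batch 2, 280 patches, hidden 768, one class token assumed):

import torch

batch, num_patches, hidden = 2, 280, 768
cls_len = 1                                   # assumed class-token length
encoder_output = torch.randn(batch, num_patches, hidden)
cls_token = torch.zeros(1, cls_len, hidden)
pos_embed = torch.nn.Embedding(num_patches + cls_len, hidden)
position_ids = torch.arange(num_patches + cls_len).expand(1, -1)

cls_tokens = cls_token.expand(batch, -1, -1)            # broadcast over batch
tokens = torch.cat((cls_tokens, encoder_output), dim=1)
# Slicing position_ids to the actual token count lets the same embedding
# table serve runs without a class token:
out = tokens + pos_embed(position_ids[:, :tokens.shape[1]])
assert out.shape == (batch, num_patches + cls_len, hidden)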
megatron/mpu/__init__.py
View file @ 7a77abd9
...
@@ -38,6 +38,7 @@ from .initialize import is_pipeline_stage_before_split, is_pipeline_stage_after_
 from .initialize import is_pipeline_stage_at_split
 from .initialize import get_num_layers
 from .initialize import get_tensor_model_parallel_src_rank
+from .initialize import get_data_parallel_src_rank
 from .initialize import get_pipeline_model_parallel_first_rank
 from .initialize import get_pipeline_model_parallel_last_rank
 from .initialize import get_pipeline_model_parallel_next_rank
...
megatron/mpu/initialize.py
View file @ 7a77abd9
...
@@ -452,6 +452,15 @@ def get_tensor_model_parallel_src_rank():
     return (global_rank // local_world_size) * local_world_size

+
+def get_data_parallel_src_rank():
+    """Calculate the global rank corresponding to the first local rank
+    in the data parallel group."""
+    global_rank = torch.distributed.get_rank()
+    data_parallel_size = get_data_parallel_world_size()
+    num_data_parallel_groups = torch.distributed.get_world_size() // data_parallel_size
+    return global_rank % num_data_parallel_groups
+
 def get_pipeline_model_parallel_first_rank():
     assert _PIPELINE_GLOBAL_RANKS is not None, \
         "Pipeline parallel group is not initialized"
...
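The modulo works because data-parallel groups are built so that ranks i and i + num_data_parallel_groups land in the same group (an assumption about the group layout that holds for Megatron's usual initialization order). A worked example with world size 8 and data-parallel size 2:

# Worked example of the rank arithmetic; no torch.distributed needed.
world_size, data_parallel_size = 8, 2
num_data_parallel_groups = world_size // data_parallel_size  # 4 groups
for global_rank in range(world_size):
    # Ranks 0 and 4 share a group, so both map to src rank 0, and so on.
    print(global_rank, "->", global_rank % num_data_parallel_groups)
# 0->0, 1->1, 2->2, 3->3, 4->0, 5->1, 6->2, 7->3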
megatron/optimizer/__init__.py
View file @ 7a77abd9
...
@@ -35,14 +35,14 @@ def _get_params_for_weight_decay_optimization(modules):
             if isinstance(module_, LayerNorm):
                 no_weight_decay_params['params'].extend(
                     [p for p in list(module_._parameters.values())
-                     if p is not None])
+                     if p is not None and p.requires_grad])
             else:
                 weight_decay_params['params'].extend(
                     [p for n, p in list(module_._parameters.items())
-                     if p is not None and n != 'bias'])
+                     if p is not None and p.requires_grad and n != 'bias'])
                 no_weight_decay_params['params'].extend(
                     [p for n, p in list(module_._parameters.items())
-                     if p is not None and n == 'bias'])
+                     if p is not None and p.requires_grad and n == 'bias'])

     return weight_decay_params, no_weight_decay_params
...
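Adding p.requires_grad to each filter keeps frozen parameters (e.g. a frozen backbone during linear probing) out of both optimizer groups. A toy demonstration of the filter (not Megatron code):

import torch

model = torch.nn.Linear(4, 4)
model.bias.requires_grad = False  # freeze the bias

weight_decay_params = {'params': [
    p for n, p in model._parameters.items()
    if p is not None and p.requires_grad and n != 'bias']}
no_weight_decay_params = {'params': [
    p for n, p in model._parameters.items()
    if p is not None and p.requires_grad and n == 'bias'], 'weight_decay': 0.0}

assert len(weight_decay_params['params']) == 1      # weight still trains
assert len(no_weight_decay_params['params']) == 0   # frozen bias excluded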
megatron/training.py
View file @ 7a77abd9
...
@@ -285,7 +285,8 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
                               args.accumulate_allreduce_grads_in_fp32,
                               args.use_contiguous_buffers_in_local_ddp)
                      for model_module in model]
+            for model_module in model:
+                model_module.broadcast_params()
         else:
             raise NotImplementedError('Unknown DDP implementation specified: '
                                       '{}. Exiting.'.format(args.DDP_impl))
...
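broadcast_params() is added to the local DDP wrapper in megatron/model/distributed.py (that hunk is not shown above). A hedged sketch of what such a broadcast typically does, pairing with the new get_data_parallel_src_rank; this is an illustrative function, not the commit's implementation:

import torch

def broadcast_params(module, data_parallel_group, src_rank):
    # Push the source replica's weights to every data-parallel replica so
    # all replicas start training from identical parameters.
    for param in module.parameters():
        torch.distributed.broadcast(param.data, src=src_rank,
                                    group=data_parallel_group)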
pretrain_vit.py
View file @ 7a77abd9
...
@@ -21,7 +21,7 @@ from functools import partial
 from megatron import get_args, get_timers, mpu, print_rank_0
 from megatron.data.vit_dataset import build_train_valid_datasets
 from megatron.model import ModelType
-from megatron.model.vit_model import VitModel
+from megatron.model.vision.classification import VitClassificationModel
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group
...
@@ -31,7 +31,7 @@ def model_provider(pre_process=True, post_process=True):
     print_rank_0("building VIT model ...")
     args = get_args()
-    model = VitModel(num_classes=args.num_classes,
-                     pre_process=pre_process,
-                     post_process=post_process)
+    model = VitClassificationModel(num_classes=args.num_classes,
+                                   pre_process=pre_process,
+                                   post_process=post_process)
     return model
...
@@ -82,7 +82,10 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
     print_rank_0(
         "> building train, validation, and test datasets " "for VIT ..."
     )
-    train_ds, valid_ds = build_train_valid_datasets(data_path=args.data_path)
+    train_ds, valid_ds = build_train_valid_datasets(
+        data_path=args.data_path,
+        image_size=(args.img_h, args.img_w)
+    )
     print_rank_0("> finished creating VIT datasets ...")
     return train_ds, valid_ds, None
...
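Read together, the dataset provider after this commit looks roughly like the following (reconstructed from the hunks above; the lines the diff elides, such as where args is bound, are assumptions):

from megatron import get_args, print_rank_0
from megatron.data.vit_dataset import build_train_valid_datasets

def train_valid_test_datasets_provider(train_val_test_num_samples):
    args = get_args()  # assumed: bound in the elided part of the function
    print_rank_0("> building train, validation, and test datasets for VIT ...")
    train_ds, valid_ds = build_train_valid_datasets(
        data_path=args.data_path,
        image_size=(args.img_h, args.img_w))   # rectangular size threaded through
    print_rank_0("> finished creating VIT datasets ...")
    return train_ds, valid_ds, None            # no test split for ViT pretraining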