OpenDAS / Megatron-LM / Commits

Commit feea48cd, authored Aug 25, 2021 by rprenger

    Merging with main

Parents: 8694c7b0, 0be40526

Changes: 52 files in this commit; this page shows 20 changed files with 968 additions and 95 deletions (+968, -95).
Changed files shown on this page:

  examples/sc21/run_figure_13.sh                                (+46,  -0)
  examples/sc21/run_figure_14.sh                                (+47,  -0)
  examples/sc21/run_figure_15.sh                                (+47,  -0)
  examples/sc21/run_figure_16.sh                                (+43,  -0)
  examples/sc21/run_figure_17.sh                                (+54,  -0)
  examples/sc21/run_figure_18.sh                                (+54,  -0)
  examples/sc21/run_table_1.sh                                  (+145, -0)
  megatron/arguments.py                                         (+42,  -10)
  megatron/checkpointing.py                                     (+35,  -15)
  megatron/data/dataset_utils.py                                (+2,   -2)
  megatron/fused_kernels/scaled_masked_softmax.cpp              (+21,  -1)
  megatron/fused_kernels/scaled_masked_softmax.h                (+16,  -3)
  megatron/fused_kernels/scaled_masked_softmax_cuda.cu          (+5,   -0)
  megatron/fused_kernels/scaled_upper_triang_masked_softmax.h   (+4,   -2)
  megatron/fused_kernels/tests/__init__.py                      (+0,   -0)
  megatron/fused_kernels/tests/test_fused_kernels.py            (+300, -0)
  megatron/initialize.py                                        (+0,   -4)
  megatron/model/fused_softmax.py                               (+73,  -45)
  megatron/model/transformer.py                                 (+30,  -13)
  megatron/mpu/initialize.py                                    (+4,   -0)
examples/sc21/run_figure_13.sh (new file, mode 100755)

#!/bin/bash

# ================================
# Choose the case to run.
# ================================

# Pipeline-parallel size options = [2, 4, 8, 16, 32].
PP=2

# Batch size (global batch size) options = [32, 128].
GBS=32

# Set pipeline-parallel and tensor-parallel size options.
TP=$((64/PP))

# Other params.
MBS=1
NLS=32
HS=20480
NAH=128
DDP=local
MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform "
NNODES=8

# Name of the job.
export JOB_NAME=results_figure_13_pipeline_parallel_size_${PP}_tensor_parallel_size_${TP}_batch_size_${GBS}

# Import the configs.
. `pwd`/CONFIG.sh

# Submit the job.
. `pwd`/SBATCH.sh

exit 0
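The script above keeps the model-parallel footprint fixed at 64 GPUs, so the tensor-parallel size follows directly from the chosen pipeline-parallel size. A minimal illustrative sketch of that relationship (not part of this commit; the variable names only mirror the script):

    # Sketch: TP derived from PP when the model-parallel footprint is 64 GPUs,
    # as in run_figure_13.sh.
    for pp in [2, 4, 8, 16, 32]:       # pipeline-parallel sizes offered by the script
        tp = 64 // pp                   # TP = 64 / PP
        print(f"PP={pp:2d}  TP={tp:2d}  model-parallel GPUs={pp * tp}")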
examples/sc21/run_figure_14.sh (new file, mode 100755)

#!/bin/bash

# ================================
# Choose the case to run.
# ================================

# Pipeline-parallel size options = [2, 4, 8, 16, 32].
PP=2

# Batch size (global batch size) options = [32, 512].
GBS=32

# Set pipeline-parallel and data-parallel size options.
DP=$((64/PP))

# Other params.
TP=1
MBS=1
NLS=32
HS=3840
NAH=32
DDP=local
MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform "
NNODES=8

# Name of the job.
export JOB_NAME=results_figure_14_pipeline_parallel_size_${PP}_data_parallel_size_${DP}_batch_size_${GBS}

# Import the configs.
. `pwd`/CONFIG.sh

# Submit the job.
. `pwd`/SBATCH.sh

exit 0
examples/sc21/run_figure_15.sh (new file, mode 100755)

#!/bin/bash

# ================================
# Choose the case to run.
# ================================

# Tensor-parallel size options = [2, 4, 8, 16, 32].
TP=2

# Batch size (global batch size) options = [32, 128, 512].
GBS=32

# Set tensor-parallel and data-parallel size options.
DP=$((64/TP))

# Other params.
PP=1
MBS=1
NLS=32
HS=3840
NAH=32
DDP=local
MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform "
NNODES=8

# Name of the job.
export JOB_NAME=results_figure_15_tensor_parallel_size_${TP}_data_parallel_size_${DP}_batch_size_${GBS}

# Import the configs.
. `pwd`/CONFIG.sh

# Submit the job.
. `pwd`/SBATCH.sh

exit 0
examples/sc21/run_figure_16.sh (new file, mode 100755)

#!/bin/bash

# ================================
# Choose the case to run.
# ================================

# Microbatch size options = [1, 2, 4, 8].
MBS=1

# Batch size (global batch size) options = [128, 512].
GBS=128

# Other params.
TP=8
PP=8
NLS=32
HS=15360
NAH=128
DDP=local
MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform "
NNODES=8

# Name of the job.
export JOB_NAME=results_figure_16_microbatch_size_${MBS}_batch_size_${GBS}

# Import the configs.
. `pwd`/CONFIG.sh

# Submit the job.
. `pwd`/SBATCH.sh

exit 0
examples/sc21/run_figure_17.sh (new file, mode 100755)

#!/bin/bash

# ================================
# Choose the case to run.
# ================================

# Activation recomputation options = [YES, NO].
ACTIVATION_RECOMPUTATION=YES

# Batch size (global batch size) options = [1, 2, 4, ..., 256].
GBS=1

# Set activation recomputation.
if [ ${ACTIVATION_RECOMPUTATION} == "YES" ]; then
    MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform "
elif [ ${ACTIVATION_RECOMPUTATION} == "NO" ]; then
    MEGATRON_EXTRA_PARAMS=""
else
    echo "Invalid configuration"
    exit 1
fi

# Other params.
TP=8
PP=16
MBS=1
NLS=80
HS=12288
NAH=96
DDP=local
NNODES=16

# Name of the job.
export JOB_NAME=results_figure_17_activation_recomputation_${ACTIVATION_RECOMPUTATION}_batch_size_${GBS}

# Import the configs.
. `pwd`/CONFIG.sh

# Submit the job.
. `pwd`/SBATCH.sh

exit 0
examples/sc21/run_figure_18.sh (new file, mode 100755)

#!/bin/bash

# ================================
# Choose the case to run.
# ================================

# Scatter-gather communication optimization options = [YES, NO].
SCATTER_GATHER=YES

# Batch size (global batch size) options = [12, 24, 36, ..., 60].
GBS=12

# Set scatter-gather communication optimization options.
if [ ${SCATTER_GATHER} == "YES" ]; then
    MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform --num-layers-per-virtual-pipeline-stage 2 "
elif [ ${SCATTER_GATHER} == "NO" ]; then
    MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform --num-layers-per-virtual-pipeline-stage 2 --no-scatter-gather-tensors-in-pipeline "
else
    echo "Invalid configuration"
    exit 1
fi

# Other params.
TP=8
PP=12
MBS=1
NLS=96
HS=12288
NAH=96
DDP=local
NNODES=12

# Name of the job.
export JOB_NAME=results_figure_18_scatter_gather_${SCATTER_GATHER}_batch_size_${GBS}

# Import the configs.
. `pwd`/CONFIG.sh

# Submit the job.
. `pwd`/SBATCH.sh

exit 0
examples/sc21/run_table_1.sh (new file, mode 100755)

#!/bin/bash

# ================================
# Choose the case to run.
# ================================

# Model size options = [1.7B, 3.6B, 7.5B, 18B, 39B, 76B, 145B, 310B, 530B, 1T].
MODEL_SIZE=1.7B

if [ ${MODEL_SIZE} == "1.7B" ]; then
    TP=1
    PP=1
    MBS=16
    GBS=512
    NLS=24
    HS=2304
    NAH=24
    DDP=torch
    NNODES=4
    MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform "
elif [ ${MODEL_SIZE} == "3.6B" ]; then
    TP=2
    PP=1
    MBS=16
    GBS=512
    NLS=30
    HS=3072
    NAH=32
    DDP=torch
    NNODES=8
    MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform "
elif [ ${MODEL_SIZE} == "7.5B" ]; then
    TP=4
    PP=1
    MBS=16
    GBS=512
    NLS=36
    HS=4096
    NAH=32
    DDP=torch
    NNODES=16
    MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform "
elif [ ${MODEL_SIZE} == "18B" ]; then
    TP=8
    PP=1
    MBS=8
    GBS=1024
    NLS=40
    HS=6144
    NAH=48
    DDP=torch
    NNODES=32
    MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform "
elif [ ${MODEL_SIZE} == "39B" ]; then
    TP=8
    PP=2
    MBS=4
    GBS=1536
    NLS=48
    HS=8192
    NAH=64
    DDP=local
    NNODES=64
    MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform "
elif [ ${MODEL_SIZE} == "76B" ]; then
    TP=8
    PP=4
    MBS=2
    GBS=1792
    NLS=60
    HS=10240
    NAH=80
    DDP=local
    NNODES=128
    MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform --num-layers-per-virtual-pipeline-stage 5"
elif [ ${MODEL_SIZE} == "145B" ]; then
    TP=8
    PP=8
    MBS=2
    GBS=2304
    NLS=80
    HS=12288
    NAH=96
    DDP=local
    NNODES=192
    MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform --num-layers-per-virtual-pipeline-stage 5 "
elif [ ${MODEL_SIZE} == "310B" ]; then
    TP=8
    PP=16
    MBS=1
    GBS=2160
    NLS=96
    HS=16384
    NAH=128
    DDP=local
    NNODES=240
    MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform --num-layers-per-virtual-pipeline-stage 3 "
elif [ ${MODEL_SIZE} == "530B" ]; then
    TP=8
    PP=35
    MBS=1
    GBS=2520
    NLS=105
    HS=20480
    NAH=128
    DDP=local
    NNODES=315
    MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform --num-layers-per-virtual-pipeline-stage 1 "
elif [ ${MODEL_SIZE} == "1T" ]; then
    TP=8
    PP=64
    MBS=1
    GBS=3072
    NLS=128
    HS=25600
    NAH=160
    DDP=local
    NNODES=384
    MEGATRON_EXTRA_PARAMS="--activations-checkpoint-method uniform "
else
    echo "Invalid configuration"
    exit 1
fi

# Name of the job.
export JOB_NAME=results_table_1_model_size_${MODEL_SIZE}

# Import the configs.
. `pwd`/CONFIG.sh

# Submit the job.
. `pwd`/SBATCH.sh

exit 0
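Each table entry fixes TP, PP, and NNODES, which together imply a data-parallel size. A small illustrative calculation for the 1T entry above (not part of this commit; the per-node GPU count of 8 is an assumption, not stated in the script):

    # Sketch: implied parallelism for the 1T configuration of run_table_1.sh.
    TP, PP, NNODES = 8, 64, 384
    GPUS_PER_NODE = 8                      # assumption: 8 GPUs per node
    total_gpus = NNODES * GPUS_PER_NODE    # 3072
    model_parallel = TP * PP               # 512-way model parallelism
    DP = total_gpus // model_parallel      # 6-way data parallelism
    print(total_gpus, model_parallel, DP)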
megatron/arguments.py

@@ -91,6 +91,13 @@ def parse_args(extra_args_provider=None, defaults={},
     assert args.model_parallel_size is None, '--model-parallel-size is no ' \
         'longer valid, use --tensor-model-parallel-size instead'
     del args.model_parallel_size
+    if args.checkpoint_activations:
+        args.activations_checkpoint_method = 'uniform'
+        if args.rank == 0:
+            print('--checkpoint-activations is no longer valid, '
+                  'use --activation-checkpoint-method instead. '
+                  'Defaulting to activation-checkpoint-method=uniform.')
+        del args.checkpoint_activations

     # Set input defaults.
     for key in defaults:

@@ -148,11 +155,15 @@ def parse_args(extra_args_provider=None, defaults={},
         print('using {} for parameters ...'.format(args.params_dtype),
               flush=True)

     # If we do accumulation and all-reduces in fp32, we need to have local DDP
-    # and we should set the use-contiguous-buffers-in-ddp.
+    # and we should make sure use-contiguous-buffers-in-local-ddp is not off.
     if args.accumulate_allreduce_grads_in_fp32:
         assert args.DDP_impl == 'local'
-        args.use_contiguous_buffers_in_ddp = True
+        assert args.use_contiguous_buffers_in_local_ddp
+
+    # For torch DDP, we do not use contiguous buffer
+    if args.DDP_impl == 'torch':
+        args.use_contiguous_buffers_in_local_ddp = False

     if args.dataloader_type is None:
         args.dataloader_type = 'single'

@@ -229,9 +240,9 @@ def parse_args(extra_args_provider=None, defaults={},
         'residual connection in fp32 only supported when using fp16 or bf16.'

     # Activation checkpointing.
     if args.distribute_checkpointed_activations:
-        assert args.checkpoint_activations, \
+        assert args.activations_checkpoint_method is not None, \
             'for distribute-checkpointed-activations to work you '\
-            'need to enable checkpoint-activations'
+            'need to use a valid checkpoint-activation method (\'uniform\' or \'block\')'

     _print_args(args)
     return args

@@ -328,6 +339,9 @@ def _add_logging_args(parser):
                        action='store_true',
                        help='If set, write validation perplexity to '
                        'tensorboard.')
+    group.add_argument('--log-memory-to-tensorboard',
+                       action='store_true',
+                       help='Enable memory logging to tensorboard.')

     return parser

@@ -394,8 +408,20 @@ def _add_training_args(parser):
                        action='store_true',
                        help='If set, distribute checkpointed activations '
                        'across model parallel group.')
-    group.add_argument('--checkpoint-num-layers', type=int, default=1,
-                       help='chunk size (number of layers) for checkpointing.')
+    group.add_argument('--activations-checkpoint-method', type=str, default=None,
+                       choices=['uniform', 'block'],
+                       help='1) uniform: uniformly divide the total number of '
+                       'Transformer layers and checkpoint the input activation of '
+                       'each divided chunk, '
+                       '2) checkpoint the input activations of only a set number of '
+                       'individual Transformer layers per pipeline stage and do the '
+                       'rest without any checkpointing'
+                       'default) do not apply activations checkpoint to any layers')
+    group.add_argument('--activations-checkpoint-num-layers', type=int, default=1,
+                       help='1) uniform: the number of Transformer layers in each '
+                       'uniformly divided checkpoint unit, '
+                       '2) block: the number of individual Transformer layers '
+                       'to checkpoint within each pipeline stage.')
     group.add_argument('--train-iters', type=int, default=None,
                        help='Total number of iterations to train over all '
                        'training runs. Note that either train-iters or '

@@ -576,9 +602,10 @@ def _add_distributed_args(parser):
                        choices=['local', 'torch'],
                        help='which DistributedDataParallel implementation '
                        'to use.')
-    group.add_argument('--use-contiguous-buffers-in-ddp', action='store_true',
-                       help='If set, use contiguous buffer in DDP. Note that '
-                       'this option only works woth local DDP.')
+    group.add_argument('--no-contiguous-buffers-in-local-ddp',
+                       action='store_false', help='If set, dont use '
+                       'contiguous buffer in local DDP.',
+                       dest='use_contiguous_buffers_in_local_ddp')
     group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
                        help='Use scatter/gather to optimize communication of tensors in pipeline',
                        dest='scatter_gather_tensors_in_pipeline')

@@ -593,6 +620,11 @@ def _add_distributed_args(parser):
     group.add_argument('--use-cpu-initialization', action='store_true',
                        default=None, help='If set, affine parallel weights '
                        'initialization uses CPU')
+    group.add_argument('--empty-unused-memory-level', default=0, type=int,
+                       choices=[0, 1, 2],
+                       help='Call torch.cuda.empty_cache() each iteration '
+                       '(training and eval), to reduce fragmentation.'
+                       '0=off, 1=moderate, 2=aggressive.')
     return parser
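The arguments change replaces the old --checkpoint-activations / --checkpoint-num-layers pair with --activations-checkpoint-method {uniform, block} and --activations-checkpoint-num-layers. A minimal sketch of how the two methods partition the layers of a pipeline stage, mirroring the _checkpointed_forward logic in megatron/model/transformer.py below (illustrative only, not repository code):

    # Sketch: which layer ranges get their input activations checkpointed
    # for a stage with `num_layers` layers and activations_checkpoint_num_layers = n.
    def checkpointed_chunks(method, num_layers, n):
        if method == 'uniform':
            # Uniformly divide the layers into chunks of n and checkpoint each chunk's input.
            return [(l, min(l + n, num_layers)) for l in range(0, num_layers, n)]
        elif method == 'block':
            # Checkpoint only the first n individual layers; run the rest without checkpointing.
            return [(l, l + 1) for l in range(min(n, num_layers))]
        raise ValueError("Invalid activation checkpoint method.")

    print(checkpointed_chunks('uniform', 8, 2))  # [(0, 2), (2, 4), (4, 6), (6, 8)]
    print(checkpointed_chunks('block', 8, 2))    # [(0, 1), (1, 2)]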
megatron/checkpointing.py

@@ -106,6 +106,40 @@ def get_checkpoint_tracker_filename(checkpoints_path):
     return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')


+def read_metadata(tracker_filename):
+    # Read the tracker file and either set the iteration or
+    # mark it as a release checkpoint.
+    iteration = 0
+    release = False
+    with open(tracker_filename, 'r') as f:
+        metastring = f.read().strip()
+        try:
+            iteration = int(metastring)
+        except ValueError:
+            release = metastring == 'release'
+            if not release:
+                print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
+                    tracker_filename))
+                sys.exit()
+    assert iteration > 0 or release, 'error parsing metadata file {}'.format(
+        tracker_filename)
+
+    # Get the max iteration retrieved across the ranks.
+    iters_cuda = torch.cuda.LongTensor([iteration])
+    torch.distributed.all_reduce(iters_cuda, op=torch.distributed.ReduceOp.MAX)
+    max_iter = iters_cuda[0].item()
+
+    # We should now have all the same iteration.
+    # If not, print a warning and chose the maximum
+    # iteration across all ranks.
+    if iteration != max_iter:
+        print('WARNING: on rank {} found iteration {} in the '
+              'metadata while max iteration across the ranks '
+              'is {}, replacing it with max iteration.'.format(
+                  rank, iteration, max_iter), flush=True)
+    return max_iter, release
+
+
 def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     """Save a model checkpoint."""
     args = get_args()

@@ -260,21 +294,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
     # Otherwise, read the tracker file and either set the iteration or
     # mark it as a release checkpoint.
-    iteration = 0
-    release = False
-    with open(tracker_filename, 'r') as f:
-        metastring = f.read().strip()
-        try:
-            iteration = int(metastring)
-        except ValueError:
-            release = metastring == 'release'
-            if not release:
-                print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
-                    tracker_filename))
-                sys.exit()
-    assert iteration > 0 or release, 'error parsing metadata file {}'.format(
-        tracker_filename)
+    iteration, release = read_metadata(tracker_filename)

     # Checkpoint.
     checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
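The new read_metadata() helper factors out the tracker-file convention: latest_checkpointed_iteration.txt holds either an integer iteration count or the literal word 'release'. A single-process sketch of just that parsing step, without the torch.distributed max-reduction (illustrative only, not repository code):

    # Sketch: parse the contents of latest_checkpointed_iteration.txt.
    def parse_tracker(text):
        metastring = text.strip()
        try:
            return int(metastring), False      # normal checkpoint at this iteration
        except ValueError:
            if metastring == 'release':
                return 0, True                  # release checkpoint
            raise ValueError('Invalid metadata file contents: %r' % metastring)

    print(parse_tracker('1500'))     # (1500, False)
    print(parse_tracker('release'))  # (0, True)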
megatron/data/dataset_utils.py

@@ -674,7 +674,7 @@ def get_samples_mapping(indexed_dataset,
     # Build samples mapping
     verbose = torch.distributed.get_rank() == 0
     start_time = time.time()
-    print_rank_0(' > building sapmles index mapping for {} ...'.format(
+    print_rank_0(' > building samples index mapping for {} ...'.format(
         name))
     # First compile and then import.
     from megatron.data import helpers

@@ -688,7 +688,7 @@ def get_samples_mapping(indexed_dataset,
         seed,
         verbose,
         2 if binary_head else 1)
-    print_rank_0(' > done building sapmles index maping')
+    print_rank_0(' > done building samples index maping')
     np.save(indexmap_filename, samples_mapping, allow_pickle=True)
     print_rank_0(' > saved the index mapping in {}'.format(
         indexmap_filename))
megatron/fused_kernels/scaled_masked_softmax.cpp

@@ -32,6 +32,12 @@ torch::Tensor bwd_cuda(
     torch::Tensor const& softmax_results,
     float scale_factor);

+int get_batch_per_block_cuda(
+    int query_seq_len,
+    int key_seq_len,
+    int batches,
+    int attn_heads);
+
 torch::Tensor fwd(
     torch::Tensor const& input,
     torch::Tensor const& mask,

@@ -63,6 +69,14 @@ torch::Tensor bwd(
   return bwd_cuda(output_grads, softmax_results, scale_factor);
 }

+int get_batch_per_block(
+    int query_seq_len,
+    int key_seq_len,
+    int batches,
+    int attn_heads) {
+    return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads);
+}
+
 } // end namespace scaled_masked_softmax
 } // end namespace fused_softmax
 } // end namespace multihead_attn

@@ -71,7 +85,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("forward",
         &multihead_attn::fused_softmax::scaled_masked_softmax::fwd,
         "Self Multihead Attention scaled, time masked softmax -- Forward.");
+
   m.def("backward",
         &multihead_attn::fused_softmax::scaled_masked_softmax::bwd,
         "Self Multihead Attention scaled, time masked softmax -- Backward.");
+
+  m.def("get_batch_per_block",
+        &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block,
+        "Return Batch per block size.");
 }
megatron/fused_kernels/scaled_masked_softmax.h

@@ -111,7 +111,7 @@ __global__ void scaled_masked_softmax_warp_forward(
     constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
     constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
     constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
-    constexpr int ELEMENTS_PER_LDG_STG = 4;
+    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

     // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
     // gridDim/blockIdx = (seq_len, attn_heads, batches)

@@ -230,7 +230,7 @@ __global__ void scaled_masked_softmax_warp_backward(
     constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
     constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
     constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
-    constexpr int ELEMENTS_PER_LDG_STG = 4;
+    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

     // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
     // gridDim/blockIdx = (seq_len, attn_heads, batches)

@@ -310,9 +310,22 @@ __global__ void scaled_masked_softmax_warp_backward(
         }
     }
 }
 } // end of anonymous namespace

+int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){
+    int log2_elements = log2_ceil(key_seq_len);
+    const int next_power_of_two = 1 << log2_elements;
+
+    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
+    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
+
+    constexpr int threads_per_block = 128;
+
+    int warps_per_block = (threads_per_block / warp_size);
+    int batches_per_block = warps_per_block * batches_per_warp;
+
+    return batches_per_block;
+}
+
 template<typename input_t, typename output_t, typename acc_t>
 void dispatch_scaled_masked_softmax_forward(
     output_t *dst,
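The new get_batch_per_block() only does integer arithmetic on the key sequence length. A worked sketch of that arithmetic for a few sequence lengths, assuming C10_WARP_SIZE is 32 (illustrative only, not repository code):

    # Sketch: batches-per-block arithmetic from get_batch_per_block().
    import math

    def get_batch_per_block(key_seq_len, warp_size_max=32, threads_per_block=128):
        next_power_of_two = 1 << math.ceil(math.log2(key_seq_len))
        warp_size = min(next_power_of_two, warp_size_max)
        batches_per_warp = 2 if next_power_of_two <= 128 else 1
        warps_per_block = threads_per_block // warp_size
        return warps_per_block * batches_per_warp

    for sk in (16, 128, 1024, 2048):
        print(sk, get_batch_per_block(sk))   # 16 -> 16, 128 -> 8, 1024 -> 4, 2048 -> 4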
megatron/fused_kernels/scaled_masked_softmax_cuda.cu

@@ -28,6 +28,11 @@ namespace multihead_attn {
 namespace fused_softmax {
 namespace scaled_masked_softmax {

+int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){
+    return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
+}
+
+
 torch::Tensor fwd_cuda(
     torch::Tensor const& input,
     torch::Tensor const& mask,
megatron/fused_kernels/scaled_upper_triang_masked_softmax.h

@@ -125,7 +125,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
     constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
     constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
     constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
-    constexpr int ELEMENTS_PER_LDG_STG = 4;
+    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

     int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
     int local_seq = blockIdx.x + 1;

@@ -245,7 +245,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
     constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
     constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
     constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
-    constexpr int ELEMENTS_PER_LDG_STG = 4;
+    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

     int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
     int local_seq = blockIdx.x + 1;

@@ -361,6 +361,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
     int warps_per_block = (threads_per_block / warp_size);
     int batches_per_block = warps_per_block * batches_per_warp;
+    TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
     int blocks_per_seq = attn_batches / batches_per_block;
     dim3 blocks(seq_len, blocks_per_seq, 1);
     dim3 threads(warp_size, warps_per_block, 1);

@@ -451,6 +452,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
     int warps_per_block = (threads_per_block / warp_size);
     int batches_per_block = warps_per_block * batches_per_warp;
+    TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
     int blocks_per_seq = attn_batches / batches_per_block;
     dim3 blocks(seq_len, blocks_per_seq, 1);
     dim3 threads(warp_size, warps_per_block, 1);
megatron/fused_kernels/tests/__init__.py (new file, mode 100644, empty)
megatron/fused_kernels/tests/test_fused_kernels.py (new file, mode 100644)

import math

import torch
from torch.nn import LayerNorm

from megatron.model.enums import AttnMaskType
from megatron.model.fused_layer_norm import MixedFusedLayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.utils import attention_mask_func


def test_load_fused_kernels():
    try:
        import fused_mix_prec_layer_norm_cuda
        import scaled_masked_softmax_cuda
        import scaled_upper_triang_masked_softmax_cuda
        import torch

        print("[Success] load_fused_kernels")
    except ImportError as e:
        print("[Fail] load_fused_kernels")
        raise e


def test_fused_softmax():
    bert = BertModel.from_pretrained("bert-base-cased").cuda().half()
    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    test_text = (
        "Hello. How are you? I am fine thank you and you? yes Good. "
        "hi hi hi hi hi hi hi hi hi hi hi hi hi"  # 32
    )

    tokens = tokenizer(
        [test_text] * 4,
        return_tensors="pt",
    )

    embedding_output = bert.embeddings(
        input_ids=tokens["input_ids"].cuda(),
        position_ids=None,
        token_type_ids=tokens["token_type_ids"].cuda(),
        inputs_embeds=None,
        past_key_values_length=0,
    )

    # (bsz, 1, 1, seq_len)
    mask = bert.get_extended_attention_mask(
        attention_mask=tokens["attention_mask"].cuda(),
        input_shape=tokens["input_ids"].shape,
        device=bert.device,
    )
    # (bsz, 1, seq_len, seq_len)
    mask = mask.repeat(1, 1, mask.size()[-1], 1)

    attention = bert.encoder.layer[0].attention.self
    key_layer = attention.transpose_for_scores(attention.key(embedding_output))
    query_layer = attention.transpose_for_scores(attention.query(embedding_output))

    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
    attention_scores /= math.sqrt(key_layer.size()[-1])

    fused_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            mask_func=attention_mask_func,
            scale=None,
            softmax_in_fp32=False,
            attn_mask_type=AttnMaskType.padding,
            scaled_masked_softmax_fusion=True,
        )
        .cuda()
        .half()
    )

    fused_softmax_output = fused_softmax(
        attention_scores,
        (mask != 0),
    )

    torch_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            mask_func=attention_mask_func,
            scale=None,
            softmax_in_fp32=False,
            attn_mask_type=AttnMaskType.padding,
            scaled_masked_softmax_fusion=False,
        )
        .cuda()
        .half()
    )

    torch_softmax_output = torch_softmax(
        attention_scores,
        (mask != 0),
    )

    test_result = (fused_softmax_output - torch_softmax_output).abs()

    while test_result.dim() != 1:
        test_result = test_result.mean(dim=-1)

    diff = test_result.mean(dim=-1)

    if diff <= 1e-3:
        print(
            f"\n[Success] test_fused_softmax"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}"
            f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
        )
    else:
        print(
            f"\n[Fail] test_fused_softmax"
            f"\n > mean_difference={diff}, "
            f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}, "
            f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
        )


def test_fused_upper_triangle_mask_softmax():
    gpt = GPT2Model.from_pretrained("gpt2").cuda().half()
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    test_text = (
        "Hello. How are you? I am fine thank you and you? yes Good. "
        "hi hi hi hi hi hi hi"  # 24
    )

    tokens = tokenizer(
        [test_text] * 4,
        return_tensors="pt",
    )

    attention_mask = tokens["attention_mask"].cuda()
    attention_mask = attention_mask.view(attention_mask.size(0), -1)
    attention_mask = attention_mask[:, None, None, :]
    attention_mask = (1.0 - attention_mask) * -10000.0
    attention_mask = attention_mask.repeat(1, 1, attention_mask.size()[-1], 1)
    attn = gpt.h[0]

    hidden_states = gpt.wte(tokens["input_ids"].cuda())
    q, k, v = attn.attn.c_attn(hidden_states).split(768, dim=-1)
    q = attn.attn._split_heads(q, attn.attn.num_heads, attn.attn.head_dim)
    k = attn.attn._split_heads(k, attn.attn.num_heads, attn.attn.head_dim)
    attn_weights = torch.matmul(q, k.transpose(-1, -2))

    sq, sk = q.size(-2), k.size(-2)
    causal_mask = attn.attn.bias[:, :, sk - sq : sk, :sk].bool()
    total_mask = ~(causal_mask & (attention_mask == 0))
    """
    tensor([[[[False,  True,  True,  ...,  True,  True,  True],
              [False, False,  True,  ...,  True,  True,  True],
              [False, False, False,  ...,  True,  True,  True],
              ...,
              [False, False, False,  ..., False,  True,  True],
              [False, False, False,  ..., False, False,  True],
              [False, False, False,  ..., False, False, False]]]
    """

    fused_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            mask_func=attention_mask_func,
            scale=None,
            softmax_in_fp32=False,
            attn_mask_type=AttnMaskType.causal,
            scaled_masked_softmax_fusion=True,
        )
        .cuda()
        .half()
    )

    fused_softmax_output = fused_softmax(
        attn_weights,
        total_mask,
    )

    torch_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            mask_func=attention_mask_func,
            scale=None,
            softmax_in_fp32=False,
            attn_mask_type=AttnMaskType.causal,
            scaled_masked_softmax_fusion=False,
        )
        .cuda()
        .half()
    )

    torch_softmax_output = torch_softmax(
        attn_weights,
        total_mask,
    )

    test_result = (fused_softmax_output - torch_softmax_output).abs()

    while test_result.dim() != 1:
        test_result = test_result.mean(dim=-1)

    diff = test_result.mean(dim=-1)

    if diff <= 1e-3:
        print(
            f"\n[Success] test_fused_upper_triangle_mask_softmax"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}"
            f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
        )
    else:
        print(
            f"\n[Fail] test_fused_upper_triangle_mask_softmax"
            f"\n > mean_difference={diff}, "
            f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}, "
            f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
        )


def test_layer_norm():
    bert = BertModel.from_pretrained("bert-base-cased").cuda().half()
    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    test_text = (
        "Hello. How are you? I am fine thank you and you? yes Good. "
        "hi hi hi hi hi hi hi hi hi hi hi hi hi"  # 32
    )

    tokens = tokenizer(
        [test_text] * 4,
        return_tensors="pt",
    )

    # [bsz, seq_len, d_model]
    embedding_output = (
        bert.embeddings(
            input_ids=tokens["input_ids"].cuda(),
            position_ids=None,
            token_type_ids=tokens["token_type_ids"].cuda(),
            inputs_embeds=None,
            past_key_values_length=0,
        )
        .cuda()
        .half()
    )

    fused_layernorm_layer = (
        MixedFusedLayerNorm(normalized_shape=embedding_output.size(-1)).cuda().half()
    )

    torch_layernorm_layer = (
        LayerNorm(normalized_shape=embedding_output.size(-1)).cuda().half()
    )

    fused_output = fused_layernorm_layer(embedding_output)
    torch_output = torch_layernorm_layer(embedding_output)
    test_result = (fused_output - torch_output).abs()

    while test_result.dim() != 1:
        test_result = test_result.mean(dim=-1)

    diff = test_result.mean(dim=-1)

    if diff <= 1e-3:
        print(
            f"\n[Success] test_layer_norm"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={fused_output[-1][-1][:5].tolist()}"
            f"\n > torch_values={torch_output[-1][-1][:5].tolist()}"
        )
    else:
        print(
            f"\n[Fail] test_layer_norm"
            f"\n > mean_difference={diff}, "
            f"\n > fused_values={fused_output[-1][-1][:5].tolist()}, "
            f"\n > torch_values={torch_output[-1][-1][:5].tolist()}"
        )


if __name__ == "__main__":
    try:
        from transformers import BertTokenizer, GPT2Tokenizer
        from transformers.models.bert.modeling_bert import BertModel
        from transformers.models.gpt2.modeling_gpt2 import GPT2Model
        import transformers

        transformers.logging.set_verbosity(
            transformers.logging.FATAL,
        )

    except:
        print("\n[Fail] Please install `transformers` package to test fused kernels\n")
        exit(-1)

    test_load_fused_kernels()
    test_fused_softmax()
    test_fused_upper_triangle_mask_softmax()
    test_layer_norm()
megatron/initialize.py

@@ -177,10 +177,6 @@ def _initialize_distributed():
             args.local_rank = device
         torch.cuda.set_device(device)
     # Call the init process
-    init_method = 'tcp://'
-    master_ip = os.getenv('MASTER_ADDR', 'localhost')
-    master_port = os.getenv('MASTER_PORT', '6000')
-    init_method += master_ip + ':' + master_port
     torch.distributed.init_process_group(
         backend=args.distributed_backend,
         world_size=args.world_size, rank=args.rank,
megatron/model/fused_softmax.py

@@ -13,7 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import torch
+import torch.nn as nn
+
 from megatron.model.enums import AttnMaskType

@@ -30,10 +32,10 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
         import scaled_upper_triang_masked_softmax_cuda

         scale_t = torch.tensor([scale])
         softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(
             inputs, scale_t[0]
         )
         ctx.save_for_backward(softmax_results, scale_t)
         return softmax_results

@@ -42,10 +44,10 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
         import scaled_upper_triang_masked_softmax_cuda

         softmax_results, scale_t = ctx.saved_tensors
         input_grads = scaled_upper_triang_masked_softmax_cuda.backward(
             output_grads, softmax_results, scale_t[0]
         )
         return input_grads, None

@@ -63,9 +65,7 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
         scale_t = torch.tensor([scale])

-        softmax_results = scaled_masked_softmax_cuda.forward(
-            inputs, mask, scale_t[0])
+        softmax_results = scaled_masked_softmax_cuda.forward(
+            inputs, mask, scale_t[0]
+        )
         ctx.save_for_backward(softmax_results, scale_t)
         return softmax_results

@@ -81,16 +81,18 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
         return input_grads, None, None


-class FusedScaleMaskSoftmax(torch.nn.Module):
+class FusedScaleMaskSoftmax(nn.Module):
     """
     fused operation: scaling + mask + softmax

     Arguments:
         input_in_fp16: flag to indicate if input in fp16 data format.
+        input_in_bf16: flag to indicate if input in bf16 data format.
         attn_mask_type: attention mask type (pad or causal)
+        scaled_masked_softmax_fusion: flag to indicate user want to use softmax fusion
         mask_func: mask function to be applied.
         softmax_in_fp32: if true, softmax in performed at fp32 precision.
         scale: scaling factor used in input tensor scaling.
     """

     def __init__(

@@ -106,8 +108,9 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
         super(FusedScaleMaskSoftmax, self).__init__()
         self.input_in_fp16 = input_in_fp16
         self.input_in_bf16 = input_in_bf16
-        assert not (self.input_in_fp16 and self.input_in_bf16), \
-            'both fp16 and bf16 flags cannot be active at the same time.'
+        assert not (
+            self.input_in_fp16 and self.input_in_bf16
+        ), "both fp16 and bf16 flags cannot be active at the same time."
         self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
         self.attn_mask_type = attn_mask_type
         self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion

@@ -118,47 +121,72 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
         assert (
             self.scale is None or softmax_in_fp32
         ), "softmax should be in fp32 when scaled"

     def forward(self, input, mask):
         # [b, np, sq, sk]
         assert input.dim() == 4
-        data_size = input.size()
-        query_seq_len = data_size[-2]
-        key_seq_len = data_size[-1]
-        attn_batch_size = data_size[0] * data_size[1]
-
-        # constraints on various tensor dimensions to enable warp based
-        # optimization and upper triangular optimization (for causal mask)
-        custom_kernel_constraint = key_seq_len > 16 and key_seq_len <= 2048 and \
-            query_seq_len % 4 == 0 and attn_batch_size % 4 == 0
-
-        # invoke custom kernel
-        if self.input_in_float16 and mask is not None and \
-           custom_kernel_constraint and self.scaled_masked_softmax_fusion:
-            scale = self.scale if self.scale is not None else 1.0
-            if self.attn_mask_type == AttnMaskType.causal:
-                assert query_seq_len == key_seq_len, \
-                    "causal mask is only for self attention"
-                input = input.view(-1, query_seq_len, key_seq_len)
-                probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
-                probs = probs.view(*data_size)
-            else:
-                assert self.attn_mask_type == AttnMaskType.padding
-                probs = ScaledMaskedSoftmax.apply(input, mask, scale)
-        else:
-            if self.input_in_float16 and self.softmax_in_fp32:
-                input = input.float()
-
-            if self.scale is not None:
-                input = input * self.scale
-            mask_output = self.mask_func(input, mask) if mask is not None else input
-            probs = torch.nn.Softmax(dim=-1)(mask_output)
-
-            if self.input_in_float16 and self.softmax_in_fp32:
-                if self.input_in_fp16:
-                    probs = probs.half()
-                else:
-                    probs = probs.bfloat16()
+
+        if self.is_kernel_available(mask, *input.size()):
+            return self.forward_fused_softmax(input, mask)
+        else:
+            return self.forward_torch_softmax(input, mask)
+
+    def is_kernel_available(self, mask, b, np, sq, sk):
+        attn_batches = b * np
+
+        if (
+            self.scaled_masked_softmax_fusion  # user want to fuse
+            and self.input_in_float16  # input must be fp16
+            and mask is not None  # mask tensor must not be None
+            and 16 < sk <= 2048  # sk must be 16 ~ 2048
+            and sq % 4 == 0  # sq must be divisor of 4
+            and attn_batches % 4 == 0  # np * b must be divisor of 4
+        ):
+            if 0 <= sk <= 2048:
+                batch_per_block = self.get_batch_per_block(sq, sk, b, np)
+
+                if self.attn_mask_type == AttnMaskType.causal:
+                    if attn_batches % batch_per_block == 0:
+                        return True
+                else:
+                    if sq % batch_per_block == 0:
+                        return True
+        return False
+
+    def forward_fused_softmax(self, input, mask):
+        b, np, sq, sk = input.size()
+        scale = self.scale if self.scale is not None else 1.0
+
+        if self.attn_mask_type == AttnMaskType.causal:
+            assert sq == sk, "causal mask is only for self attention"
+
+            # input is 3D tensor (attn_batches, sq, sk)
+            input = input.view(-1, sq, sk)
+            probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
+            return probs.view(b, np, sq, sk)
+        else:
+            # input is 4D tensor (b, np, sq, sk)
+            return ScaledMaskedSoftmax.apply(input, mask, scale)
+
+    def forward_torch_softmax(self, input, mask):
+        if self.input_in_float16 and self.softmax_in_fp32:
+            input = input.float()
+
+        if self.scale is not None:
+            input = input * self.scale
+        mask_output = self.mask_func(input, mask) if mask is not None else input
+        probs = torch.nn.Softmax(dim=-1)(mask_output)
+
+        if self.input_in_float16 and self.softmax_in_fp32:
+            if self.input_in_fp16:
+                probs = probs.half()
+            else:
+                probs = probs.bfloat16()

         return probs
+
+    @staticmethod
+    def get_batch_per_block(sq, sk, b, np):
+        import scaled_masked_softmax_cuda
+
+        return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np)
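The refactored forward() now dispatches through is_kernel_available(), which combines the old shape constraints with the new batch-per-block divisibility check. A standalone sketch of that predicate, with get_batch_per_block stubbed by the arithmetic from scaled_masked_softmax.h rather than the CUDA extension (illustrative only, not repository code):

    # Sketch: when the fused kernel path is taken.
    import math

    def batch_per_block(sk, threads_per_block=128, warp_size_max=32):
        np2 = 1 << math.ceil(math.log2(sk))
        warp_size = min(np2, warp_size_max)
        batches_per_warp = 2 if np2 <= 128 else 1
        return (threads_per_block // warp_size) * batches_per_warp

    def kernel_available(fusion, in_float16, has_mask, causal, b, np, sq, sk):
        attn_batches = b * np
        if not (fusion and in_float16 and has_mask
                and 16 < sk <= 2048 and sq % 4 == 0 and attn_batches % 4 == 0):
            return False
        bpb = batch_per_block(sk)
        return attn_batches % bpb == 0 if causal else sq % bpb == 0

    print(kernel_available(True, True, True, False, b=4, np=12, sq=128, sk=128))  # True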
megatron/model/transformer.py

@@ -53,8 +53,7 @@ class ParallelMLP(MegatronModule):
     MLP will take the input with h hidden state, project it to 4*h
     hidden dimension, perform nonlinear transformation, and project the
-    state back into h hidden dimension. At the end, dropout is also
-    applied.
+    state back into h hidden dimension.
     """

     def __init__(self, init_method, output_layer_init_method):

@@ -84,7 +83,6 @@ class ParallelMLP(MegatronModule):
             init_method=output_layer_init_method,
             skip_bias_add=True)
-
     def forward(self, hidden_states):

         # [s, b, 4hp]

@@ -544,8 +542,8 @@ class ParallelTransformer(MegatronModule):
         self.input_tensor = None

         # Store activation checkpoiting flag.
-        self.checkpoint_activations = args.checkpoint_activations
-        self.checkpoint_num_layers = args.checkpoint_num_layers
+        self.activations_checkpoint_method = args.activations_checkpoint_method
+        self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers

         # Number of layers.
         assert args.num_layers % mpu.get_pipeline_model_parallel_world_size() == 0, \

@@ -611,12 +609,31 @@ class ParallelTransformer(MegatronModule):
         # Make sure memory is freed.
         mpu.reset_checkpointed_activations_memory_buffer()

-        l = 0
-        while l < self.num_layers:
-            hidden_states = mpu.checkpoint(
-                custom(l, l + self.checkpoint_num_layers),
-                hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
-            l += self.checkpoint_num_layers
+        if self.activations_checkpoint_method == 'uniform':
+            # Uniformly divide the total number of Transformer layers and checkpoint
+            # the input activation of each divided chunk.
+            # A method to further reduce memory usage reducing checkpoints.
+            l = 0
+            while l < self.num_layers:
+                hidden_states = mpu.checkpoint(
+                    custom(l, l + self.activations_checkpoint_num_layers),
+                    hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
+                l += self.activations_checkpoint_num_layers
+        elif self.activations_checkpoint_method == 'block':
+            # Checkpoint the input activation of only a set number of individual
+            # Transformer layers and skip the rest.
+            # A method fully use the device memory removing redundant re-computation.
+            for l in range(self.num_layers):
+                if l < self.activations_checkpoint_num_layers:
+                    hidden_states = mpu.checkpoint(
+                        custom(l, l + 1),
+                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
+                else:
+                    hidden_states = custom(l, l + 1)(
+                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
+        else:
+            raise ValueError("Invalid activation checkpoint method.")

         return hidden_states

@@ -639,7 +656,7 @@ class ParallelTransformer(MegatronModule):
                 'for not None values in layer_past, ' \
                 'expected get_key_value to be set'
         if get_key_value:
-            assert not self.checkpoint_activations, \
+            assert self.activations_checkpoint_method is None, \
                 'get_key_value does not work with ' \
                 'activation checkpointing'

@@ -658,7 +675,7 @@ class ParallelTransformer(MegatronModule):
         if encoder_output is not None:
             encoder_output = encoder_output.transpose(0, 1).contiguous()

-        if self.checkpoint_activations:
+        if self.activations_checkpoint_method is not None:
             hidden_states = self._checkpointed_forward(hidden_states,
                                                        attention_mask,
                                                        encoder_output,
megatron/mpu/initialize.py

@@ -356,9 +356,13 @@ def get_data_parallel_rank():

 def destroy_model_parallel():
     """Set the groups to none."""
+    global _MODEL_PARALLEL_GROUP
+    _MODEL_PARALLEL_GROUP = None
     global _TENSOR_MODEL_PARALLEL_GROUP
     _TENSOR_MODEL_PARALLEL_GROUP = None
     global _PIPELINE_MODEL_PARALLEL_GROUP
     _PIPELINE_MODEL_PARALLEL_GROUP = None
     global _DATA_PARALLEL_GROUP
     _DATA_PARALLEL_GROUP = None
+    global _EMBEDDING_GROUP
+    _EMBEDDING_GROUP = None