OpenDAS / Megatron-LM · Commit c92f10bd

Merge branch 'main' into tridao-flashattn

Authored Jan 11, 2023 by Jared Casper
Parents: 9200e43a, b7071993

Showing 6 changed files with 212 additions and 53 deletions.
  megatron/arguments.py                 +40   -0
  megatron/fused_kernels/__init__.py     +9   -3
  megatron/model/transformer.py        +147  -43
  megatron/training.py                   +7   -0
  pretrain_gpt.py                        +3   -1
  tasks/zeroshot_gpt/evaluate.py         +6   -6
megatron/arguments.py

@@ -28,6 +28,7 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
     parser = _add_vision_args(parser)
     parser = _add_logging_args(parser)
     parser = _add_inference_args(parser)
+    parser = _add_transformer_engine_args(parser)

     # Custom arguments.
     if extra_args_provider is not None:

@@ -304,6 +305,18 @@ def validate_args(args, defaults={}):
             'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \
             'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR)

+    # Tranformer-Engine/FP8 related checking
+    if args.fp8_e4m3 or args.fp8_hybrid:
+        assert args.transformer_impl == 'transformer_engine', \
+            'transformer-engine required for fp8 training and inference'
+
+    assert not (args.fp8_e4m3 and args.fp8_hybrid), \
+        'cannot train with both fp8 e4m3 and hybrid formatting'
+
+    if args.fp16:
+        assert args.transformer_impl == 'local', \
+            'transformer-engine not yet approved for fp16 training and inference'
+
     if args.recompute_granularity == 'selective':
         assert args.recompute_method is None, \
             'recompute method is not yet supported for ' \

@@ -355,6 +368,33 @@ def _check_arg_is_not_none(args, arg):
     assert getattr(args, arg) is not None, '{} argument is None'.format(arg)


+def _add_transformer_engine_args(parser):
+    group = parser.add_argument_group(title='Transformer-Engine')
+
+    group.add_argument('--fp8-e4m3', action='store_true',
+                       help='E4M3 TransformerLayer', dest='fp8_e4m3')
+    group.add_argument('--fp8-hybrid', action='store_true',
+                       help='Hybrid FP8 TransformerLayer', dest='fp8_hybrid')
+    group.add_argument('--no-fp8-wgrad', action='store_false',
+                       help='Execute wgrad in higher precision even for FP8 runs',
+                       dest='fp8_wgrad')
+    group.add_argument('--fp8-margin', type=int, default=0,
+                       help='Scaling margin for fp8', dest='fp8_margin')
+    group.add_argument('--fp8-interval', type=int, default=1,
+                       help='Scaling update interval for fp8', dest='fp8_interval')
+    group.add_argument('--transformer-impl', default='local',
+                       choices=['local', 'transformer_engine'],
+                       help='Which Transformer implementation to use.',
+                       dest='transformer_impl')
+    group.add_argument('--fp8-amax-history-len', type=int, default=1,
+                       help='Number of steps for which amax history is recorded per tensor',
+                       dest='fp8_amax_history_len')
+    group.add_argument('--fp8-amax-compute-algo', default='most_recent',
+                       choices=['most_recent', 'max'],
+                       help='Algorithm for computing amax from history',
+                       dest='fp8_amax_compute_algo')
+
+    return parser
+
+
 def _add_inference_args(parser):
     group = parser.add_argument_group(title='inference')
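For readers unfamiliar with the flag wiring above: the new group can be exercised on its own. Below is a minimal, stand-alone sketch that mirrors the added argument group (it is not the Megatron parser; parse_args/validate_args and every other group are omitted) and shows how the new flags map onto args attributes via dest.

# Minimal sketch of the Transformer-Engine argument group added in this commit.
# Illustration only; the real flags live in megatron/arguments.py.
import argparse

def add_transformer_engine_args(parser):
    group = parser.add_argument_group(title='Transformer-Engine')
    group.add_argument('--fp8-e4m3', action='store_true', dest='fp8_e4m3')
    group.add_argument('--fp8-hybrid', action='store_true', dest='fp8_hybrid')
    group.add_argument('--no-fp8-wgrad', action='store_false', dest='fp8_wgrad')
    group.add_argument('--fp8-margin', type=int, default=0, dest='fp8_margin')
    group.add_argument('--fp8-interval', type=int, default=1, dest='fp8_interval')
    group.add_argument('--transformer-impl', default='local',
                       choices=['local', 'transformer_engine'], dest='transformer_impl')
    group.add_argument('--fp8-amax-history-len', type=int, default=1, dest='fp8_amax_history_len')
    group.add_argument('--fp8-amax-compute-algo', default='most_recent',
                       choices=['most_recent', 'max'], dest='fp8_amax_compute_algo')
    return parser

parser = add_transformer_engine_args(argparse.ArgumentParser())
# Example: request hybrid-format FP8 through the Transformer Engine implementation.
args = parser.parse_args(['--fp8-hybrid', '--transformer-impl', 'transformer_engine'])
print(args.fp8_hybrid, args.fp8_wgrad, args.transformer_impl)
# -> True True transformer_engine  (passing --no-fp8-wgrad would flip fp8_wgrad to False)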
megatron/fused_kernels/__init__.py

@@ -18,11 +18,14 @@ def load(args):
     # Check if cuda 11 is installed for compute capability 8.0
     cc_flag = []
-    _, bare_metal_major, _ = _get_cuda_bare_metal_version(
+    _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version(
         cpp_extension.CUDA_HOME)
     if int(bare_metal_major) >= 11:
         cc_flag.append('-gencode')
         cc_flag.append('arch=compute_80,code=sm_80')
+        if int(bare_metal_minor) >= 7:
+            cc_flag.append('-gencode')
+            cc_flag.append('arch=compute_90,code=sm_90')

     # Build path
     srcpath = pathlib.Path(__file__).parent.absolute()

@@ -75,11 +78,14 @@ def load(args):
     # Mixed precision fused layer norm.
     # =================================

+    extra_hopper_flags = ['-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__']
     extra_cuda_flags = ['-maxrregcount=50']
     sources = [srcpath / 'layer_norm_cuda.cpp',
                srcpath / 'layer_norm_cuda_kernel.cu']
     fused_mix_prec_layer_norm_cuda = _cpp_extention_load_helper(
-        "fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags)
+        "fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags + extra_hopper_flags)

     # =================================
     # Fused gradient accumulation to weight gradient computation of linear layer

@@ -89,7 +95,7 @@ def load(args):
     sources = [srcpath / 'fused_weight_gradient_dense.cpp',
                srcpath / 'fused_weight_gradient_dense.cu']
     fused_dense_cuda = _cpp_extention_load_helper(
-        "fused_dense_cuda", sources, [])
+        "fused_dense_cuda", sources, extra_hopper_flags)


 def _get_cuda_bare_metal_version(cuda_dir):
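The net effect of the first hunk: when the detected CUDA toolkit is 11.7 or newer, the fused kernels are additionally compiled for compute capability 9.0 (Hopper). A small stand-alone sketch of that gating follows; the version strings here are made up for illustration, while the real code derives them from _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME).

# Stand-alone illustration of the compute-capability gating added above.
def gencode_flags(cuda_version):
    major, minor = (int(x) for x in cuda_version.split('.')[:2])
    cc_flag = []
    if major >= 11:
        cc_flag += ['-gencode', 'arch=compute_80,code=sm_80']      # Ampere, CUDA 11+
        if minor >= 7:
            cc_flag += ['-gencode', 'arch=compute_90,code=sm_90']  # Hopper, CUDA 11.7+
    return cc_flag

print(gencode_flags('11.4'))  # Ampere only
print(gencode_flags('11.8'))  # Ampere + Hopper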
megatron/model/transformer.py

@@ -6,7 +6,7 @@ from contextlib import nullcontext
 import torch
 import torch.nn.functional as F

-from megatron import get_timers, get_args, core
+from megatron import get_timers, get_args, core, get_num_microbatches
 from .module import MegatronModule
 from megatron.core import mpu, tensor_parallel
 from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType

@@ -25,7 +25,6 @@ try:
 except ImportError:
     flash_attn_unpadded_func = None

 """ We use the following notation throughout this file:
      h: hidden size
      n: number of attention heads
@@ -890,6 +889,7 @@ class ParallelTransformer(MegatronModule):
         self.post_process = post_process
         self.input_tensor = None
         self.drop_path_rate = drop_path_rate
+        self.transformer_impl = args.transformer_impl

         # Store activation checkpoiting flag.
         self.recompute_granularity = args.recompute_granularity
@@ -900,6 +900,31 @@ class ParallelTransformer(MegatronModule):
         self.sequence_parallel = args.sequence_parallel

+        # Transformer Engine Init.
+        if self.transformer_impl == 'transformer_engine':
+            global transformer_engine
+            import transformer_engine
+        self.use_fp8 = args.fp8_e4m3 or args.fp8_hybrid
+        self.fp8_recipe = None
+        self.fp8_group = mpu.get_data_parallel_group()
+        if self.use_fp8:
+            if args.fp8_e4m3:
+                fp8_format = transformer_engine.common.recipe.Format.E4M3
+            elif args.fp8_hybrid:
+                fp8_format = transformer_engine.common.recipe.Format.HYBRID
+            self.fp8_recipe = transformer_engine.common.recipe.DelayedScaling(
+                margin=args.fp8_margin,
+                interval=args.fp8_interval,
+                fp8_format=fp8_format,
+                amax_history_len=args.fp8_amax_history_len,
+                amax_compute_algo=args.fp8_amax_compute_algo,
+                override_linear_precision=(False, False, not args.fp8_wgrad),
+            )
+
+        self.num_microbatches_in_previous_step = -1
+        self.microbatch_count = 0
+        self.checkpoint_core_attention = args.recompute_granularity == 'selective'
+
         # Number of layers.
         self.num_layers = _get_num_layers(
             args,
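To make the recipe construction above easier to follow in isolation, here is a hedged, self-contained sketch that maps the new command-line flags onto a Transformer Engine DelayedScaling recipe. It uses only the calls that appear in the diff itself, guards the optional transformer_engine import, and substitutes made-up flag values for the parsed Megatron args.

# Hedged sketch: how the new flags translate into a TE FP8 recipe (illustration only).
from types import SimpleNamespace

# Stand-in for the parsed Megatron args; values chosen purely for illustration.
args = SimpleNamespace(fp8_e4m3=False, fp8_hybrid=True, fp8_margin=0, fp8_interval=1,
                       fp8_amax_history_len=1, fp8_amax_compute_algo='most_recent',
                       fp8_wgrad=True)

try:
    import transformer_engine
except ImportError:
    transformer_engine = None

if transformer_engine is None:
    print('transformer_engine not installed; skipping recipe construction')
else:
    fmt = (transformer_engine.common.recipe.Format.E4M3 if args.fp8_e4m3
           else transformer_engine.common.recipe.Format.HYBRID)
    recipe = transformer_engine.common.recipe.DelayedScaling(
        margin=args.fp8_margin,
        interval=args.fp8_interval,
        fp8_format=fmt,
        amax_history_len=args.fp8_amax_history_len,
        amax_compute_algo=args.fp8_amax_compute_algo,
        # wgrad stays in higher precision when --no-fp8-wgrad sets fp8_wgrad=False
        override_linear_precision=(False, False, not args.fp8_wgrad),
    )
    print(recipe)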
@@ -910,6 +935,7 @@ class ParallelTransformer(MegatronModule):
         # Transformer layers.
         def build_layer(layer_number):
+            if args.transformer_impl == 'local':
                 return ParallelTransformerLayer(
                     init_method,
                     output_layer_init_method,
@@ -917,6 +943,35 @@ class ParallelTransformer(MegatronModule):
                     layer_type=layer_type,
                     self_attn_mask_type=self_attn_mask_type,
                     drop_path_rate=self.drop_path_rates[layer_number - 1])
+            else:
+                return transformer_engine.pytorch.TransformerLayer(
+                    args.hidden_size,
+                    args.ffn_hidden_size,
+                    args.num_attention_heads,
+                    layernorm_epsilon=args.layernorm_epsilon,
+                    hidden_dropout=args.hidden_dropout,
+                    attention_dropout=args.attention_dropout,
+                    init_method=init_method,
+                    output_layer_init_method=output_layer_init_method,
+                    layer_number=layer_number,
+                    kv_channels=args.kv_channels,
+                    self_attn_mask_type=self_attn_mask_type.name,
+                    tp_group=mpu.get_tensor_model_parallel_group(),
+                    get_rng_state_tracker=tensor_parallel.get_cuda_rng_tracker,
+                    fuse_wgrad_accumulation=args.gradient_accumulation_fusion,
+                    apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
+                    attention_softmax_in_fp32=args.attention_softmax_in_fp32,
+                    seq_length=args.seq_length,
+                    micro_batch_size=args.micro_batch_size,
+                    sequence_parallel=args.sequence_parallel,
+                    params_dtype=args.params_dtype,
+                    apply_residual_connection_post_layernorm=args.apply_residual_connection_post_layernorm,
+                    output_layernorm=False,
+                    layer_type="encoder",
+                    drop_path_rate=self.drop_path_rates[layer_number - 1],
+                    set_parallel_mode=True,
+                    fuse_qkv_params=True)

         if args.virtual_pipeline_model_parallel_size is not None:
             assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
                 'num_layers_per_stage must be divisible by ' \
@@ -976,19 +1031,20 @@ class ParallelTransformer(MegatronModule):
         return self.layers[layer_number]

     def _checkpointed_forward(self, hidden_states, attention_mask,
-                              encoder_output, enc_dec_attn_mask):
+                              encoder_output, enc_dec_attn_mask,
+                              is_first_microbatch):
         """Forward method with activation checkpointing."""
-        def custom(start, end):
-            def custom_forward(*inputs):
-                x_ = inputs[0]
-                attention_mask = inputs[1]
-                encoder_output = inputs[2]
-                enc_dec_attn_mask = inputs[3]
+        def custom(start, end, is_transformer_engine=False):
+            def custom_forward(*args, **kwargs):
                 for index in range(start, end):
                     layer = self._get_layer(index)
-                    x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
+                    x_ = layer(*args, **kwargs)
                 return x_
+            def custom_forward_transformer_engine(*args, **kwargs):
+                return custom_forward(*args, is_first_microbatch=is_first_microbatch, **kwargs)
+            if not is_transformer_engine:
                 return custom_forward
+            else:
+                return custom_forward_transformer_engine

         if self.recompute_method == 'uniform':
             # Uniformly divide the total number of Transformer layers and checkpoint
@@ -996,10 +1052,19 @@ class ParallelTransformer(MegatronModule):
             # A method to further reduce memory usage reducing checkpoints.
             l = 0
             while l < self.num_layers:
+                if self.transformer_impl == 'transformer_engine':
+                    hidden_states = transformer_engine.pytorch.distributed.checkpoint(
+                        custom(l, l + self.recompute_num_layers, is_transformer_engine=True),
+                        self.distribute_saved_activations,
+                        tensor_parallel.get_cuda_rng_tracker,
+                        mpu.get_tensor_model_parallel_group(),
+                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
+                else:
                     hidden_states = tensor_parallel.checkpoint(
                         custom(l, l + self.recompute_num_layers),
                         self.distribute_saved_activations,
                         hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)

                 l += self.recompute_num_layers

         elif self.recompute_method == 'block':
@@ -1008,10 +1073,22 @@ class ParallelTransformer(MegatronModule):
             # A method fully use the device memory removing redundant re-computation.
             for l in range(self.num_layers):
                 if l < self.recompute_num_layers:
+                    if self.transformer_impl == 'transformer_engine':
+                        hidden_states = transformer_engine.pytorch.distributed.checkpoint(
+                            custom(l, l + 1, is_transformer_engine=True),
+                            self.distribute_saved_activations,
+                            tensor_parallel.get_cuda_rng_tracker,
+                            mpu.get_tensor_model_parallel_group(),
+                            hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
+                    else:
                         hidden_states = tensor_parallel.checkpoint(
                             custom(l, l + 1),
                             self.distribute_saved_activations,
                             hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                 else:
+                    if self.transformer_impl == 'transformer_engine':
+                        hidden_states = custom(l, l + 1, is_transformer_engine=True)(
+                            hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
+                    else:
                         hidden_states = custom(l, l + 1)(
                             hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
@@ -1071,21 +1148,48 @@ class ParallelTransformer(MegatronModule):
             rng_context = nullcontext()

         with rng_context:
+            # The fp8_autocast context manager is a no-op when enabled=True
+            # The if...else serves to short circuit name resolution for fp8_autocast
+            with transformer_engine.pytorch.fp8_autocast(
+                enabled=self.use_fp8,
+                fp8_recipe=self.fp8_recipe,
+                fp8_group=self.fp8_group
+            ) if self.use_fp8 else nullcontext():
+                # Determine if the current iteration is first microbatch
+                if self.num_microbatches_in_previous_step != get_num_microbatches():
+                    self.microbatch_count = 0  # Reset count on new batch size rampup interval
+                self.num_microbatches_in_previous_step = get_num_microbatches()
+                is_first_microbatch = self.microbatch_count % get_num_microbatches() == 0

                 # Forward pass.
                 if self.recompute_granularity == 'full':
                     hidden_states = self._checkpointed_forward(hidden_states,
                                                                attention_mask,
                                                                encoder_output,
-                                                               enc_dec_attn_mask)
+                                                               enc_dec_attn_mask,
+                                                               is_first_microbatch)
                 else:
+                    forward_kwargs = {
+                        'encoder_output': encoder_output,
+                        'enc_dec_attn_mask': enc_dec_attn_mask,
+                        'inference_params': inference_params,
+                    }
+
+                    if self.transformer_impl == 'transformer_engine':
+                        forward_kwargs['is_first_microbatch'] = is_first_microbatch
+                        forward_kwargs['checkpoint_core_attention'] = self.checkpoint_core_attention
+
                     for index in range(self.num_layers):
                         layer = self._get_layer(index)
                         hidden_states = layer(
                             hidden_states,
                             attention_mask,
-                            encoder_output=encoder_output,
-                            enc_dec_attn_mask=enc_dec_attn_mask,
-                            inference_params=inference_params)
+                            **forward_kwargs)
+
+                # Skip counter update for eval and activation checkpointing
+                if torch.is_grad_enabled() and self.training:
+                    self.microbatch_count += 1

         # Final layer norm.
         if self.post_process and self.post_layer_norm:
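The is_first_microbatch bookkeeping introduced above depends only on two counters, so its behaviour can be shown with a toy re-implementation. This is an illustration, not the Megatron code path; the number of microbatches is passed in explicitly here instead of calling get_num_microbatches.

# Toy illustration of the first-microbatch tracking introduced above.
class MicrobatchTracker:
    def __init__(self):
        self.num_microbatches_in_previous_step = -1
        self.microbatch_count = 0

    def step(self, num_microbatches, training=True):
        if self.num_microbatches_in_previous_step != num_microbatches:
            self.microbatch_count = 0  # Reset count on new batch size rampup interval
        self.num_microbatches_in_previous_step = num_microbatches
        is_first_microbatch = self.microbatch_count % num_microbatches == 0
        if training:  # counter update is skipped for eval and activation checkpointing
            self.microbatch_count += 1
        return is_first_microbatch

tracker = MicrobatchTracker()
print([tracker.step(num_microbatches=4) for _ in range(8)])
# -> [True, False, False, False, True, False, False, False]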
megatron/training.py

@@ -26,6 +26,7 @@ from megatron.checkpointing import load_checkpoint
 from megatron.checkpointing import save_checkpoint
 from megatron.model import Float16Module
 from megatron.model import ModelType
+from megatron.model import GPTModel
 from megatron.optimizer import get_megatron_optimizer
 from megatron.initialize import initialize_megatron
 from megatron.initialize import write_args_to_tensorboard

@@ -251,6 +252,12 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
     if not isinstance(model, list):
         model = [model]

+    # Disallow training and inference with Transformer Engine
+    # for non-GPT models
+    args.allow_transformer_engine = all([type(m) == GPTModel for m in model])
+    assert args.allow_transformer_engine or args.transformer_impl == 'local', \
+        'Transformer Engine is only approved for GPT models'
+
     # Set tensor model parallel attributes if not set.
     # Only parameters that are already tensor model parallel have these
     # attributes set for them. We should make sure the default attributes
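Note that the gate uses an exact type check, type(m) == GPTModel, rather than isinstance, so subclasses of GPTModel would not be allowed through. A tiny illustration with stand-in classes (not the real Megatron models):

# Stand-in classes to illustrate the exact-type gate added above.
class GPTModel: ...
class BertModel: ...
class CustomGPT(GPTModel): ...   # a subclass does NOT pass an exact type check

def allow_transformer_engine(model):
    return all([type(m) == GPTModel for m in model])

print(allow_transformer_engine([GPTModel(), GPTModel()]))  # True
print(allow_transformer_engine([BertModel()]))             # False
print(allow_transformer_engine([CustomGPT()]))             # False (isinstance would say True)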
pretrain_gpt.py

@@ -113,4 +113,6 @@ if __name__ == "__main__":
     pretrain(train_valid_test_datasets_provider, model_provider,
              ModelType.encoder_or_decoder,
-             forward_step, args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
+             forward_step,
+             args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}
+             )
tasks/zeroshot_gpt/evaluate.py

@@ -9,7 +9,7 @@ import torch
 from megatron import get_args
 from megatron import print_rank_0, is_last_rank
 from megatron import get_tokenizer
-from megatron.core import mpu
+from megatron.core import parallel_state, tensor_parallel
 from megatron.checkpointing import load_checkpoint
 from megatron.model import GPTModel
 from megatron.training import get_model

@@ -90,10 +90,10 @@ def forward_step(batch, model, eval_metric):
     send_forward(output)

-    if mpu.is_pipeline_last_stage():
+    if parallel_state.is_pipeline_last_stage():
         # For loss, return the unreduced loss.
         if eval_metric == 'loss':
-            losses = mpu.tensor_parallel.vocab_parallel_cross_entropy(
+            losses = tensor_parallel.vocab_parallel_cross_entropy(
                 output.contiguous().float(), labels.contiguous())
             loss = torch.sum(
                 losses.view(-1) * loss_mask.contiguous().view(-1).float())

@@ -129,9 +129,9 @@ def evaluate(data_loader, model, eval_metric):
             output = forward_step(batch, model, eval_metric)

             # Reduce across processes.
-            if mpu.is_pipeline_last_stage():
+            if parallel_state.is_pipeline_last_stage():
                 torch.distributed.all_reduce(output,
-                                             group=mpu.get_data_parallel_group())
+                                             group=parallel_state.get_data_parallel_group())

                 total_output += output