OpenDAS / Megatron-LM / Commits / 3aca1415

Commit 3aca1415, authored Apr 29, 2024 by liangjing

Merge branch 'megatron-lm_dtk24.04' into 'main'

Megatron lm dtk24.04. See merge request !1

Parents: 0024a5c6, 1005e9d3
Pipeline #1806 passed with stage
Changes: 204 | Pipelines: 3
Showing 20 changed files with 1793 additions and 16 deletions (+1793, -16)
megatron/core/dist_checkpointing/strategies/__init__.py      +16   -0
megatron/core/dist_checkpointing/strategies/base.py          +68   -0
megatron/core/dist_checkpointing/strategies/tensorstore.py   +110  -0
megatron/core/dist_checkpointing/strategies/two_stage.py     +249  -0
megatron/core/dist_checkpointing/strategies/zarr.py          +230  -0
megatron/core/dist_checkpointing/utils.py                    +44   -0
megatron/core/enums.py                                       +4    -1
megatron/core/fusions/__init__.py                            +0    -0
megatron/core/fusions/fused_bias_dropout.py                  +60   -0
megatron/core/fusions/fused_bias_gelu.py                     +48   -0
megatron/core/fusions/fused_layer_norm.py                    +119  -0
megatron/core/fusions/fused_softmax.py                       +204  -0
megatron/core/inference_params.py                            +27   -0
megatron/core/model_parallel_config.py                       +167  -0
megatron/core/models/__init__.py                             +0    -0
megatron/core/models/common/__init__.py                      +0    -0
megatron/core/models/common/rotary_pos_embedding.py          +15   -15
megatron/core/models/gpt/__init__.py                         +1    -0
megatron/core/models/gpt/gpt_embedding.py                    +123  -0
megatron/core/models/gpt/gpt_model.py                        +308  -0
megatron/core/dist_checkpointing/strategies/__init__.py  0 → 100644

# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.

""" Various loading and saving strategies """

import logging

logger = logging.getLogger(__name__)

try:
    import tensorstore
    import zarr

    from .tensorstore import _import_trigger
    from .zarr import _import_trigger
except ImportError:
    logger.warning('Zarr-based strategies will not be registered because of missing packages')
megatron/core/dist_checkpointing/strategies/base.py  0 → 100644

# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.

from abc import ABC, abstractmethod
from collections import defaultdict
from enum import Enum
from pathlib import Path
from typing import Dict, List, Optional

from ..mapping import CheckpointingException, ShardedStateDict, ShardedTensor, StateDict


class StrategyAction(Enum):
    LOAD_COMMON = 'load_common'
    LOAD_SHARDED = 'load_sharded'
    SAVE_COMMON = 'save_common'
    SAVE_SHARDED = 'save_sharded'


default_strategies = defaultdict(dict)


def get_default_strategy(action: StrategyAction, backend: str, version: int):
    try:
        return default_strategies[action.value][(backend, version)]
    except KeyError as e:
        raise CheckpointingException(
            f'Cannot find default strategy for: {(action, backend, version)}'
        ) from e


class LoadStrategyBase(ABC):
    @abstractmethod
    def check_backend_compatibility(self, loaded_version):
        raise NotImplementedError

    @abstractmethod
    def check_version_compatibility(self, loaded_version):
        raise NotImplementedError


class SaveStrategyBase(ABC):
    def __init__(self, backend: str, version: int):
        self.backend = backend
        self.version = version


class LoadCommonStrategy(LoadStrategyBase):
    @abstractmethod
    def load(self, checkpoint_dir: Path):
        raise NotImplementedError


class LoadShardedStrategy(LoadStrategyBase):
    @abstractmethod
    def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
        raise NotImplementedError


class SaveCommonStrategy(SaveStrategyBase):
    @abstractmethod
    def save(self, common_state_dict: StateDict, checkpoint_dir: Path):
        raise NotImplementedError


class SaveShardedStrategy(SaveStrategyBase):
    @abstractmethod
    def save(self, sharded_tensors: List[ShardedTensor], checkpoint_dir: Path):
        raise NotImplementedError
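
The registry above keys a concrete strategy by a StrategyAction plus a (backend, version) pair. A minimal sketch of registering and resolving a custom strategy follows; it is illustrative only and not part of this commit, and the 'my_backend' name and NoOpSaveShardedStrategy class are hypothetical.

# Illustration only (not in this commit): 'my_backend' and NoOpSaveShardedStrategy
# are hypothetical names.
from pathlib import Path

from megatron.core.dist_checkpointing.strategies.base import (
    SaveShardedStrategy,
    StrategyAction,
    default_strategies,
    get_default_strategy,
)


class NoOpSaveShardedStrategy(SaveShardedStrategy):
    # Minimal concrete subclass: just records which tensors it was asked to save.
    def save(self, sharded_tensors, checkpoint_dir: Path):
        self.saved_keys = [ten.key for ten in sharded_tensors]


# Register under a (backend, version) key and resolve it through the same
# registry the Zarr/TensorStore strategies below populate.
default_strategies[StrategyAction.SAVE_SHARDED.value][('my_backend', 1)] = (
    NoOpSaveShardedStrategy('my_backend', 1)
)
strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, 'my_backend', 1)
assert strategy.backend == 'my_backend' and strategy.version == 1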
megatron/core/dist_checkpointing/strategies/tensorstore.py  0 → 100644

# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.

""" Strategies using TensorStore to load and save Zarr arrays. """

from functools import partial
from itertools import starmap
from pathlib import Path

import tensorstore as ts
import torch

from ..core import CheckpointingException
from ..dict_utils import dict_list_map_inplace
from ..mapping import ShardedStateDict, ShardedTensor
from .base import LoadShardedStrategy, StrategyAction, default_strategies
from .zarr import postprocess_numpy_array

_import_trigger = None


class TensorStoreLoadShardedStrategy(LoadShardedStrategy):
    def __init__(self, load_directly_on_device: bool = False):
        super().__init__()
        self.load_directly_on_device = load_directly_on_device

    def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
        if torch.distributed.get_rank() == 0:
            print(f'Loading distributed checkpoint with {self.__class__.__name__}')
            if self.load_directly_on_device:
                print(f'Loading distributed checkpoint directly on the GPU')
        load_fn = partial(
            _load_from_array,
            checkpoint_dir=checkpoint_dir,
            load_directly_on_device=self.load_directly_on_device,
        )
        dict_list_map_inplace(load_fn, sharded_state_dict)
        return sharded_state_dict

    def check_backend_compatibility(self, loaded_version):
        pass  # TODO

    def check_version_compatibility(self, loaded_version):
        pass  # TODO


def merge_global_slice_with_shape(global_slice, actual_shape, key):
    def _merge_slice(dim_slice, dim_size):
        if isinstance(dim_slice, slice):
            assert (
                dim_slice.start < dim_size
            ), f'Got empty slice for ShardedTensor {key} ({dim_slice}, {dim_size})'
            if dim_slice.stop > dim_size:
                dim_slice = slice(dim_slice.start, dim_size, dim_slice.step)
        return dim_slice

    assert len(global_slice) == len(actual_shape), (global_slice, actual_shape, key)
    return tuple(starmap(_merge_slice, zip(global_slice, actual_shape)))


def _load_from_array(
    sharded_tensor: ShardedTensor,
    checkpoint_dir: Path,
    load_directly_on_device: bool = False,
    apply_flattened_range: bool = True,
):
    x = _load_regular_chunk(sharded_tensor, checkpoint_dir)
    ten = postprocess_numpy_array(x, sharded_tensor, apply_flattened_range)
    if load_directly_on_device:
        sharded_tensor.data.data.copy_(ten)
        return sharded_tensor.data
    else:
        return ten


def _load_regular_chunk(sharded_tensor: ShardedTensor, checkpoint_dir: Path):
    assert isinstance(sharded_tensor, ShardedTensor), type(sharded_tensor)
    spec = {'driver': 'zarr', 'metadata_key': '.zarray', 'kvstore': {}}
    spec['kvstore'] = {
        'driver': 'file',
        'path': str(checkpoint_dir / sharded_tensor.key),
    }
    try:
        arr = ts.open(ts.Spec(spec), open=True).result()
    except Exception as e:
        raise CheckpointingException(
            f'Array {checkpoint_dir / sharded_tensor.key} could not be loaded. Error: {e}'
        ) from e

    if sharded_tensor.global_shape == arr.shape:
        x = (
            arr[sharded_tensor.global_slice()].read().result()
        )  # flattened tensors loading is delayed
    elif sharded_tensor.allow_shape_mismatch:
        global_slice = merge_global_slice_with_shape(
            sharded_tensor.global_slice(), arr.shape, sharded_tensor.key
        )
        x = arr[global_slice].read().result()  # flattened tensors loading is delayed
    else:
        _msg = (
            f'Global shape mismatch for loaded ({arr.shape})'
            f' and expected ({sharded_tensor.global_shape}) tensor'
            f' for key {sharded_tensor.key}'
        )
        raise CheckpointingException(_msg)

    return x


default_strategies[StrategyAction.LOAD_SHARDED.value][('zarr', 1)] = TensorStoreLoadShardedStrategy()
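
_load_regular_chunk builds a TensorStore spec with the 'zarr' driver and a 'file' kvstore rooted at the checkpoint directory. The sketch below spells that spec out for an assumed checkpoint path and tensor key; it is not part of this commit.

# Illustration only (not in this commit); the checkpoint path and tensor key
# below are made up.
checkpoint_dir = '/tmp/dist_ckpt'
key = 'decoder.layers.0.mlp.weight'

spec = {
    'driver': 'zarr',              # interpret the chunk as a Zarr array
    'metadata_key': '.zarray',     # per-array metadata file written by zarr
    'kvstore': {'driver': 'file', 'path': f'{checkpoint_dir}/{key}'},
}

# With the `ts` alias from the file above, ts.open returns a future;
# .result() blocks until the array handle is ready, and reads are likewise
# futures resolved with .read().result():
#     arr = ts.open(ts.Spec(spec), open=True).result()
#     x = arr[...].read().result()   # numpy array for the requested slice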
megatron/core/dist_checkpointing/strategies/two_stage.py  0 → 100644

# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.

""" 2-stage checkpoint loading. """
import os
import time
from collections import defaultdict
from dataclasses import dataclass
from functools import partial, wraps
from itertools import chain
from logging import DEBUG, INFO, StreamHandler, getLogger
from operator import attrgetter, itemgetter
from pathlib import Path
from typing import Iterable, List, NamedTuple, Optional, Tuple, Union

import torch

from ..dict_utils import dict_list_map_inplace, map_reduce, nested_values
from ..mapping import ShardedStateDict, ShardedTensor, StateDict
from .base import LoadShardedStrategy
from .tensorstore import _load_from_array
from .zarr import flatten_range

_import_trigger = None

timers = defaultdict(list)

logger = getLogger(__name__)


def timed(verbose=True):
    def timed_dec(fn):
        name = fn.__name__

        @wraps(fn)
        def wrapped(*args, **kwargs):
            if verbose:
                logger.debug(f'{name} init')
            start = time.time()
            ret = fn(*args, **kwargs)
            took = time.time() - start
            if verbose:
                logger.debug(f'{name} took {took}s')
            timers[name].append(took)
            return ret

        return wrapped

    return timed_dec


@dataclass
class _ShardedTensorMetadata:
    global_rank: int
    sharded_tensor_no_data: ShardedTensor
    dist_group_rank: Tuple[int]  # id of distributed group
    dist_group_ranks: Tuple[int]  # id of distributed group
    data_size: Optional[int] = None  # bytes


def sharded_tensor_chunk_id(sharded_tensor: ShardedTensor):
    return (
        sharded_tensor.key,
        sharded_tensor.global_offset,
    )


class TwoStageDataParallelLoadShardedStrategy(LoadShardedStrategy):
    """ Loads one checkpoint replica from storage and broadcasts to other nodes.

    This strategy loads checkpoint from storage on minimal set of nodes
    and distributes the checkpoint to other nodes with torch.distributed.
    Loading is performed with tensorstore.

    Steps:
    0. (optional) create Gloo distributed groups
    1. Exchange ShardedTensors metadata between all nodes
    2. Align needed tensors within DP groups
    3. For each globally unique tensor:
      a) on one of the ranks load it from storage to CPU and move to CUDA
      b) allocate CUDA tensor on other ranks
      c) broadcast within DP group
      d) copy tensor content to the model param location
      e) free tensor buffers from a) and b)

    Notes:
    1. Loading and broadcasting is done sequentially to avoid both host and device OOMs
    2. There is a lot of overlap potential between all three steps done for each tensor:
      a) loading from storage to numpy
      b) moving CPU tensors to CUDA
      c) broadcast
    """

    def __init__(self, data_parallel_group, cpu_transfer=True):
        super().__init__()

        self.cpu_transfer = cpu_transfer
        self.data_parallel_group_orig = data_parallel_group
        self.data_parallel_group = None if cpu_transfer else data_parallel_group
        self.dp_group_ranks = tuple(
            sorted(torch.distributed.get_process_group_ranks(data_parallel_group))
        )
        self.dp_group_rank = torch.distributed.get_rank(self.data_parallel_group_orig)
        self.global_rank = torch.distributed.get_rank()

    def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
        self.maybe_init_gloo_group()
        all_tensors_sorted = self._build_load_plan(sharded_state_dict)
        self._exchange_loaded_tensors(all_tensors_sorted, sharded_state_dict, checkpoint_dir)
        self.summarize_load_times()
        return sharded_state_dict

    def summarize_load_times(self):
        torch.distributed.barrier()
        logger.info('Checkpoint loading finished. Summary:')
        for key, times in sorted(timers.items()):
            times_sum = sum(times)
            max_times = torch.tensor([times_sum], device='cuda')
            avg_times = torch.tensor([times_sum], device='cuda')
            torch.distributed.all_reduce(max_times, op=torch.distributed.ReduceOp.MAX)
            torch.distributed.all_reduce(avg_times, op=torch.distributed.ReduceOp.SUM)
            avg_times /= torch.distributed.get_world_size()
            if torch.distributed.get_rank() == 0:
                logger.info(f'{key}: max {max_times[0]}, avg {avg_times[0]}')

    @timed(verbose=False)
    def load_tensor_from_storage(self, checkpoint_dir, ten_meta: _ShardedTensorMetadata):
        logger.debug(f'_load_from_array({ten_meta.sharded_tensor_no_data.key}) init')
        ret = _load_from_array(
            ten_meta.sharded_tensor_no_data,
            checkpoint_dir,
            load_directly_on_device=False,
            apply_flattened_range=False,
        )
        logger.debug(f'_load_from_array({ten_meta.sharded_tensor_no_data.key}) DONE')
        return ret

    @timed()
    def maybe_init_gloo_group(self):
        if not self.cpu_transfer:
            return
        all_groups = [None] * torch.distributed.get_world_size()
        torch.distributed.all_gather_object(all_groups, self.dp_group_ranks)
        all_groups = set(tuple(sorted(gr)) for gr in all_groups)

        for group_ranks in sorted(all_groups):
            gloo_pg = torch.distributed.new_group(ranks=group_ranks, backend='gloo')
            if self.global_rank in group_ranks:
                self.data_parallel_group = gloo_pg
                assert self.dp_group_rank == torch.distributed.get_rank(self.data_parallel_group)

    def check_backend_compatibility(self, loaded_version):
        pass  # TODO

    def check_version_compatibility(self, loaded_version):
        pass  # TODO

    @timed()
    def _build_load_plan(
        self, sharded_state_dict: ShardedStateDict
    ) -> List[_ShardedTensorMetadata]:
        local_meta = [
            _ShardedTensorMetadata(
                self.global_rank,
                sharded_ten.without_data(),
                self.dp_group_rank,
                self.dp_group_ranks,
            )
            for sharded_ten in nested_values(sharded_state_dict)
        ]
        all_meta = [None] * torch.distributed.get_world_size(group=self.data_parallel_group)
        torch.distributed.all_gather_object(all_meta, local_meta, group=self.data_parallel_group)
        all_meta = list(chain.from_iterable(all_meta))
        all_tensors_sorted = self.deduplicate_chunks(all_meta)
        return all_tensors_sorted

    @timed()
    def deduplicate_chunks(self, ten_metas: List[_ShardedTensorMetadata]):
        """ Group tensors by chunk and then pick the tensor with the lowest rank.

        NOTE: with proper loading overlap, loading from randomized ranks
        (instead of the smallest one) could be beneficial here.
        """
        ten_metas = map_reduce(
            ten_metas,
            key_fn=lambda meta: sharded_tensor_chunk_id(meta.sharded_tensor_no_data),
            reduce_fn=partial(min, key=attrgetter('dist_group_rank')),
        )
        all_metas_sorted = list(map(itemgetter(1), sorted(ten_metas.items())))
        return all_metas_sorted

    @timed()
    def _exchange_loaded_tensors(
        self, ten_metas: List[_ShardedTensorMetadata], sharded_state_dict, checkpoint_dir
    ):
        logger.debug(f'_exchange_loaded_tensors, num ten_metas: {len(ten_metas)}')
        for ten_meta in ten_metas:

            src_rank = torch.distributed.get_global_rank(
                self.data_parallel_group, ten_meta.dist_group_rank
            )

            if self.dp_group_rank == ten_meta.dist_group_rank:
                exchange_tensor = self.load_tensor_from_storage(checkpoint_dir, ten_meta)
                if not self.cpu_transfer:
                    exchange_tensor = exchange_tensor.cuda()
            else:
                # TODO: for non-flattened ranges we could reuse the buffer from the start here
                exchange_tensor = torch.empty(
                    ten_meta.sharded_tensor_no_data.local_shape,
                    device='cpu' if self.cpu_transfer else 'cuda',
                    dtype=ten_meta.sharded_tensor_no_data.dtype,
                )

            logger.debug(
                f'exchange {ten_meta.sharded_tensor_no_data.key}, {exchange_tensor.shape}'
                f'({exchange_tensor.numel()}), broadcast({src_rank} -> {self.dp_group_ranks})'
            )
            torch.distributed.broadcast(
                exchange_tensor, group=self.data_parallel_group, src=src_rank
            )
            self._distribute_data_to_state_dict(ten_meta, exchange_tensor, sharded_state_dict)
            logger.debug(f'exchange {ten_meta.sharded_tensor_no_data.key} done')

            # free buffer memory
            exchange_tensor = None

    @timed(verbose=False)
    def _distribute_data_to_state_dict(
        self,
        ten_meta: _ShardedTensorMetadata,
        loaded_ten: torch.Tensor,
        sharded_state_dict: ShardedStateDict,
    ):
        tensor_key = sharded_tensor_chunk_id(ten_meta.sharded_tensor_no_data)

        def _fill_in_data(t: Union[ShardedTensor, torch.Tensor]):
            if not isinstance(t, ShardedTensor) or sharded_tensor_chunk_id(t) != tensor_key:
                # already filled-in or key not matching
                return t
            sharded_tensor: ShardedTensor = t
            x = loaded_ten
            if sharded_tensor.flattened_range is not None:
                x = flatten_range(sharded_tensor, x)

            # Reuse existing buffer
            sharded_tensor.data.data.copy_(x)
            return sharded_tensor.data

        dict_list_map_inplace(_fill_in_data, sharded_state_dict)
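
The docstring above describes the load-once-then-broadcast flow. A minimal usage sketch follows; it is not part of this commit and assumes torch.distributed is already initialized and that the surrounding training script supplies the data-parallel group, the sharded state dict, and the checkpoint directory.

from pathlib import Path

from megatron.core.dist_checkpointing.strategies.two_stage import (
    TwoStageDataParallelLoadShardedStrategy,
)


def load_with_two_stage(sharded_state_dict, checkpoint_dir: Path, data_parallel_group):
    # cpu_transfer=True broadcasts through per-DP-group Gloo groups (step 0 of the
    # docstring), keeping the temporary full-tensor buffers in host memory.
    strategy = TwoStageDataParallelLoadShardedStrategy(data_parallel_group, cpu_transfer=True)
    return strategy.load(sharded_state_dict, checkpoint_dir)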
megatron/core/dist_checkpointing/strategies/zarr.py  0 → 100644

# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.

""" Strategies using Zarr as an underlying format. """
import os
from functools import partial
from pathlib import Path
from typing import List

import numpy as np
import torch
import zarr

from ..core import CheckpointingException
from ..dict_utils import dict_list_map_inplace
from ..mapping import ShardedStateDict, ShardedTensor, is_main_replica
from .base import LoadShardedStrategy, SaveShardedStrategy, StrategyAction, default_strategies

numpy_to_torch_dtype_dict = {
    np.bool_: torch.bool,
    np.uint8: torch.uint8,
    np.int8: torch.int8,
    np.int16: torch.int16,
    np.int32: torch.int32,
    np.int64: torch.int64,
    np.float16: torch.float16,
    np.float32: torch.float32,
    np.float64: torch.float64,
    np.complex64: torch.complex64,
    np.complex128: torch.complex128,
}

torch_to_numpy_dtype_dict = {v: k for k, v in numpy_to_torch_dtype_dict.items()}


try:
    import tensorstore

    HAS_BFLOAT16 = True
    numpy_to_torch_dtype_dict[np.dtype('bfloat16')] = torch.bfloat16
    torch_to_numpy_dtype_dict[torch.bfloat16] = np.dtype('bfloat16')
except ImportError:
    HAS_BFLOAT16 = False

_import_trigger = None


class ZarrSaveShardedStrategy(SaveShardedStrategy):
    def save(self, sharded_tensors: List[ShardedTensor], checkpoint_dir: Path):
        arrays = _create_or_open_zarr_arrays(sharded_tensors, checkpoint_dir)
        for ten, arr in zip(sharded_tensors, arrays):
            _save_to_existing_array(ten, arr)
        torch.distributed.barrier()


def _create_or_open_zarr_arrays(
    sharded_tensors: List[ShardedTensor], checkpoint_dir: Path
) -> List[zarr.Array]:
    arrays = []
    for ten in sharded_tensors:
        if _should_create_array(ten):
            _create_zarr_array(ten, checkpoint_dir)
            # TODO: maybe reuse the opened arrays

    torch.distributed.barrier()
    for ten in sharded_tensors:
        # if is_main_replica(ten.replica_id) and set(ten.global_offset) == {0}:
        #     continue
        open_kwargs = {}
        if ten.flattened_range is not None:
            open_kwargs['synchronizer'] = zarr.ProcessSynchronizer(
                str(checkpoint_dir / f'{ten.key}.sync')
            )
        arr = zarr.open(checkpoint_dir / ten.key, 'r+', **open_kwargs)
        arrays.append(arr)
    return arrays


def _should_create_array(ten: ShardedTensor):
    return (
        is_main_replica(ten.replica_id)
        and set(ten.global_offset) == {0}
        and (ten.flattened_range is None or ten.flattened_range.start == 0)
    )


def _save_to_existing_array(sharded_tensor: ShardedTensor, arr: zarr.Array):
    if not is_main_replica(sharded_tensor.replica_id):
        return
    x = sharded_tensor.data
    x = x.detach().cpu()
    torch.cuda.synchronize()
    if x.dtype == torch.bfloat16:
        x = x.float()
        x = x.numpy()
        x = x.astype('bfloat16')
    else:
        x = x.numpy()

    if sharded_tensor.flattened_range is None:
        arr[sharded_tensor.global_slice()] = x
    else:
        arr.set_coordinate_selection(sharded_tensor.global_coordinates(), x)


def _create_zarr_array(sharded_tensor: ShardedTensor, checkpoint_dir: Path):
    np_dtype = torch_to_numpy_dtype_dict[sharded_tensor.dtype]
    try:
        arr = zarr.create(
            sharded_tensor.global_shape,
            dtype=np_dtype,
            store=checkpoint_dir / sharded_tensor.key,
            chunks=sharded_tensor.max_allowed_chunks(),
            compressor=None,
            fill_value=None,
            write_empty_chunks=True,
        )
    except zarr.errors.ContainsArrayError as e:
        raise CheckpointingException(
            f'Array {checkpoint_dir / sharded_tensor.key} already exists'
        ) from e

    if HAS_BFLOAT16 and np_dtype == np.dtype('bfloat16'):
        arr._dtype = np_dtype
        zarray = arr.store['.zarray']
        arr.store['.zarray'] = zarray.replace(b'<V2', b'bfloat16')
    return arr


class ZarrLoadShardedStrategy(LoadShardedStrategy):
    def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
        dict_list_map_inplace(
            partial(_load_from_array, checkpoint_dir=checkpoint_dir), sharded_state_dict
        )
        return sharded_state_dict

    def check_backend_compatibility(self, loaded_version):
        pass  # TODO

    def check_version_compatibility(self, loaded_version):
        pass  # TODO


def _load_from_array(sharded_tensor: ShardedTensor, checkpoint_dir: Path):
    assert isinstance(sharded_tensor, ShardedTensor), type(sharded_tensor)
    try:
        arr = zarr.open(checkpoint_dir / sharded_tensor.key, 'r')
    except zarr.errors.PathNotFoundError as e:
        raise CheckpointingException(
            f'Array {checkpoint_dir / sharded_tensor.key} not found'
        ) from e

    if not sharded_tensor.allow_shape_mismatch and sharded_tensor.global_shape != arr.shape:
        _msg = (
            f'Global shape mismatch for loaded ({arr.shape})'
            f' and expected ({sharded_tensor.global_shape}) tensor'
            f' for key {sharded_tensor.key}'
        )
        raise CheckpointingException(_msg)

    x = arr[sharded_tensor.global_slice()]  # flattened tensors loading is delayed
    return postprocess_numpy_array(x, sharded_tensor)


def postprocess_numpy_array(loaded_array, sharded_tensor, apply_flattened_range=True):
    x = loaded_array
    if HAS_BFLOAT16 and x.dtype == np.dtype('bfloat16'):
        x = x.astype(np.dtype('float32'))
        x = torch.from_numpy(x)
        x = x.bfloat16()
    else:
        x = torch.from_numpy(x)
    # TODO: consider some other consistency checks
    if x.shape != sharded_tensor.local_shape:
        if sharded_tensor.allow_shape_mismatch:
            x = pad_to_expected_shape(x, sharded_tensor)
        else:
            _msg = (
                f'Local shape mismatch for loaded ({x.shape})'
                f' and expected ({sharded_tensor.local_shape}) tensor'
                f' for key {sharded_tensor.key}'
            )
            raise CheckpointingException(_msg)

    if apply_flattened_range and sharded_tensor.flattened_range is not None:
        x = flatten_range(sharded_tensor, x)

    # TODO: consider cuda() tensors support
    return x


def flatten_range(sharded_tensor, x):
    return x.flatten()[sharded_tensor.flattened_range]


def pad_to_expected_shape(x: torch.Tensor, expected_sharded_ten: ShardedTensor):
    pad_args = []
    assert len(x.shape) == len(expected_sharded_ten.local_shape)
    # Reversed iteration order because F.pad expects so
    for x_sh, exp_sh, axis_fragm in reversed(
        list(
            zip(
                x.shape,
                expected_sharded_ten.local_shape,
                expected_sharded_ten.axis_fragmentations,
            )
        )
    ):
        if x_sh == exp_sh:
            pad_args.extend((0, 0))
        elif x_sh > exp_sh:
            assert (
                False
            ), f'Expected shape ({exp_sh}) smaller than actual ({x_sh}) for {repr(expected_sharded_ten)}'
        else:
            pad_args.extend((0, exp_sh - x_sh))
    # TODO: behavior control with envvar is for testing purposes only, remove it
    if not int(os.environ.get('DIST_CKPT_PAD_REPLICATE', 0)):
        return torch.nn.functional.pad(x, pad_args)

    # unsqueeze and squeeze to get shapes supported by cudnn
    print(f'Replicating last row for {expected_sharded_ten.key}')
    if x.dtype == torch.bfloat16:
        return (
            torch.nn.functional.pad(x.float().unsqueeze(0), pad_args, mode='replicate')
            .squeeze(0)
            .bfloat16()
        )
    return torch.nn.functional.pad(x.unsqueeze(0), pad_args, mode='replicate').squeeze(0)


# default_strategies[StrategyAction.LOAD_SHARDED.value][('zarr', 1)] = ZarrLoadShardedStrategy()
default_strategies[StrategyAction.SAVE_SHARDED.value][('zarr', 1)] = ZarrSaveShardedStrategy('zarr', 1)
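
The dtype tables above are inverses of each other, with bfloat16 added only when the tensorstore import succeeds (HAS_BFLOAT16); that is why _save_to_existing_array detours through float32 before casting to the 'bfloat16' dtype string. A small check, not part of this commit:

import numpy as np
import torch

from megatron.core.dist_checkpointing.strategies.zarr import (
    numpy_to_torch_dtype_dict,
    torch_to_numpy_dtype_dict,
)

# The two tables invert each other for the standard dtypes.
assert torch_to_numpy_dtype_dict[torch.float16] == np.float16
assert numpy_to_torch_dtype_dict[np.float32] == torch.float32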
megatron/core/dist_checkpointing/utils.py  0 → 100644

# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.

from typing import Tuple

from .dict_utils import dict_list_map_inplace, extract_matching_values
from .mapping import (
    LocalNonpersitentObject,
    ShardedStateDict,
    ShardedTensor,
    ShardedTensorFactory,
    StateDict,
)


def extract_sharded_tensors(
    sharded_state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, StateDict]:
    return extract_matching_values(sharded_state_dict, lambda v: isinstance(v, ShardedTensor))


def extract_sharded_tensors_and_factories(
    sharded_state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, StateDict]:
    return extract_matching_values(
        sharded_state_dict, lambda v: isinstance(v, (ShardedTensor, ShardedTensorFactory))
    )


def extract_sharded_tensors_or_nonpersistent(
    sharded_state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, StateDict]:
    return extract_matching_values(
        sharded_state_dict,
        lambda v: isinstance(v, (ShardedTensor, LocalNonpersitentObject, ShardedTensorFactory)),
    )


def add_prefix_for_sharding(sharded_state_dict: ShardedStateDict, prefix: str):
    def add_prefix(t):
        if isinstance(t, ShardedTensor):
            t.key = f'{prefix}.{t.key}'
        return t

    dict_list_map_inplace(add_prefix, sharded_state_dict)
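
A hedged sketch, not part of this commit, of how these helpers are typically combined: prefix the ShardedTensor keys of a sub-module and then split the mixed state dict into the sharded part and the remainder, following the Tuple[ShardedStateDict, StateDict] return annotations above. The module and prefix names are assumptions.

from megatron.core.dist_checkpointing.utils import (
    add_prefix_for_sharding,
    extract_sharded_tensors,
)


def split_sharded_state_dict(module, prefix):
    # `module` is assumed to expose sharded_state_dict(), as GPTEmbedding below does.
    sharded_state_dict = module.sharded_state_dict()
    add_prefix_for_sharding(sharded_state_dict, prefix)  # rewrites ShardedTensor.key in place
    sharded_part, rest = extract_sharded_tensors(sharded_state_dict)
    return sharded_part, rest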
megatron/core/enums.py

-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
 import enum


 class ModelType(enum.Enum):
     encoder_or_decoder = 1
     encoder_and_decoder = 2
+    retro_encoder = 3
+    retro_decoder = 4
tests/pipeline_parallel/__init__.py → megatron/core/fusions/__init__.py

File moved
megatron/core/fusions/fused_bias_dropout.py  0 → 100644

# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

from typing import Optional, Tuple

import torch


def _bias_dropout_add_func(x, bias, residual, prob, training):
    # type: (Tensor, Optional[Tensor], Tensor, float, bool) -> Tensor
    # NOTE: Previously, the argument `bias` used to be passed as
    # `bias.expand_as(residual)` when the `bias_dropout_func` is called from the
    # transformer layer but broadcasting should automatically take care of that.
    # Also, looking at broadcasting semantics, `expand_as` and broadcasting
    # seem to be identical performance-wise (both just change the view).

    # If we want to train mixed precision, then the output of this function
    # should be half precision. However, in AMP O1, the input (residual) is
    # in fp32, and it will up-cast the result to fp32, causing pipeline parallel
    # GPU communication to hang. Therefore, we need to cast residual to the same
    # dtype as x.
    residual = residual if residual.dtype == x.dtype else residual.to(x.dtype)

    if bias is not None:
        x = x + bias
    out = torch.nn.functional.dropout(x, p=prob, training=training)
    out = residual + out
    return out


@torch.jit.script
def bias_dropout_add_fused_train(
    x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]], residual: torch.Tensor, prob: float,
) -> torch.Tensor:
    x, bias = x_with_bias  # unpack
    return _bias_dropout_add_func(x, bias, residual, prob, True)


@torch.jit.script
def bias_dropout_add_fused_inference(
    x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]], residual: torch.Tensor, prob: float,
) -> torch.Tensor:
    x, bias = x_with_bias  # unpack
    return _bias_dropout_add_func(x, bias, residual, prob, False)


def get_bias_dropout_add(training, fused):
    def unfused_bias_dropout_add(x_with_bias, residual, prob):
        x, bias = x_with_bias  # unpack
        return _bias_dropout_add_func(x, bias, residual, prob, training)

    if fused:
        # jit scripting for a nn.module (with dropout) is not
        # triggering the fusion kernel. For now, we use two
        # different nn.functional routines to account for varying
        # dropout semantics during training and inference phases.
        if training:
            return bias_dropout_add_fused_train
        else:
            return bias_dropout_add_fused_inference
    else:
        return unfused_bias_dropout_add
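
A usage sketch, not part of this commit, showing the contract of get_bias_dropout_add: the returned callable takes an (output, bias) tuple plus a residual and a dropout probability. The shapes are arbitrary, and the unfused path is used so the example also runs on CPU.

import torch

from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add

hidden = torch.randn(4, 2, 8)      # assumed [s, b, h] activations
bias = torch.randn(8)              # broadcasts over the leading dimensions
residual = torch.randn(4, 2, 8)

bias_dropout_add = get_bias_dropout_add(training=True, fused=False)
out = bias_dropout_add((hidden, bias), residual, 0.1)
assert out.shape == residual.shape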
megatron/core/fusions/fused_bias_gelu.py  0 → 100644

# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

import torch

###### BIAS GELU FUSION/ NO AUTOGRAD ################
# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2)   -> 0.70710678
# sqrt(2/pi)  -> 0.79788456
# this function is tanh approximation of gelu
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))


@torch.jit.script
def bias_gelu(bias, y):
    x = bias + y
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))


# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def bias_gelu_back(g, bias, y):
    x = bias + y
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
        1 + tanh_out
    )
    return ff * g


class GeLUFunction(torch.autograd.Function):
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, bias):
        ctx.save_for_backward(input, bias)
        return bias_gelu(bias, input)

    @staticmethod
    def backward(ctx, grad_output):
        input, bias = ctx.saved_tensors
        tmp = bias_gelu_back(grad_output, bias, input)
        return tmp, tmp


bias_gelu_impl = GeLUFunction.apply
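
A quick numerical check, not part of this commit, that the tanh approximation in bias_gelu tracks the exact erf-based GELU referenced in the comments.

import torch

from megatron.core.fusions.fused_bias_gelu import bias_gelu

y = torch.randn(16, 32)
bias = torch.randn(32)

approx = bias_gelu(bias, y)                   # tanh approximation defined above
exact = torch.nn.functional.gelu(y + bias)    # erf-based GELU
assert torch.allclose(approx, exact, atol=1e-2)

During training, bias_gelu_impl (the GeLUFunction.apply alias) is the entry point, so the hand-written bias_gelu_back gradient is used instead of autograd tracing through the approximation.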
megatron/core/fusions/fused_layer_norm.py  0 → 100644

# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import importlib
import numbers

import torch
from torch.nn import init
from torch.nn.parameter import Parameter

from megatron.core.utils import make_viewless_tensor

try:
    from apex.contrib.layer_norm.layer_norm import FastLayerNormFN

    HAVE_PERSIST_LAYER_NORM = True
except:
    HAVE_PERSIST_LAYER_NORM = False

try:
    from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction

    HAVE_FUSED_LAYER_NORM = True
except:
    HAVE_FUSED_LAYER_NORM = False


class FusedLayerNorm(torch.nn.Module):
    def __init__(
        self,
        hidden_size,
        eps=1e-5,
        persist_layer_norm=True,
        sequence_parallel=False,
        zero_centered_gamma=False,
    ):
        super().__init__()

        self.zero_centered_gamma = zero_centered_gamma

        # List of hiddens sizes supported in the persistent layer norm kernel
        # If the hidden size is not supported, fall back to the non-persistent
        # kernel.
        persist_ln_hidden_sizes = [
            1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192, 10240,
            12288, 12800, 15360, 16384, 18432, 20480, 24576, 25600, 30720,
            32768, 40960, 49152, 65536,
        ]
        if hidden_size not in persist_ln_hidden_sizes or not HAVE_PERSIST_LAYER_NORM:
            persist_layer_norm = False

        if not persist_layer_norm and not HAVE_FUSED_LAYER_NORM:
            # TODO: Add pytorch only layer norm
            raise ValueError(f'Apex must currently be installed to use megatron core.')

        if isinstance(hidden_size, numbers.Integral):
            hidden_size = (hidden_size,)
        self.hidden_size = torch.Size(hidden_size)
        self.eps = eps
        self.weight = Parameter(torch.Tensor(*hidden_size))
        self.bias = Parameter(torch.Tensor(*hidden_size))
        self.reset_parameters()
        self.persist_layer_norm = persist_layer_norm
        self.sequence_parallel = sequence_parallel

        # set sequence parallelism flag on weight and bias parameters
        setattr(self.weight, 'sequence_parallel', self.sequence_parallel)
        setattr(self.bias, 'sequence_parallel', self.sequence_parallel)

    def reset_parameters(self):
        if self.zero_centered_gamma:
            init.zeros_(self.weight)
            init.zeros_(self.bias)
        else:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def forward(self, input):
        weight = self.weight + 1 if self.zero_centered_gamma else self.weight

        if self.persist_layer_norm:
            output = FastLayerNormFN.apply(input, weight, self.bias, self.eps)

            # Apex's fast layer norm function outputs a 'view' tensor (i.e., has
            # a populated '_base' field). This will result in schedule.py's
            # deallocate_output_tensor() throwing an error, so a viewless tensor is
            # created to prevent this.
            output = make_viewless_tensor(
                inp=output, requires_grad=input.requires_grad, keep_graph=True
            )
        else:
            output = FusedLayerNormAffineFunction.apply(
                input, weight, self.bias, self.hidden_size, self.eps
            )

        return output
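
A usage sketch, not part of this commit; it assumes Apex and a CUDA device are available, since the constructor raises when neither the persistent nor the fused Apex kernel can be imported. With zero_centered_gamma the weight stores gamma - 1, so a zero-initialized weight still acts as gamma = 1 in forward().

import torch

from megatron.core.fusions.fused_layer_norm import FusedLayerNorm

ln = FusedLayerNorm(hidden_size=4096, eps=1e-5, zero_centered_gamma=True).cuda()
x = torch.randn(8, 2, 4096, device='cuda')   # assumed [s, b, h] activations
y = ln(x)
assert y.shape == x.shape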
megatron/core/fusions/fused_softmax.py  0 → 100644

# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

import torch
import torch.nn as nn

from megatron.core.transformer.enums import AttnMaskType


class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs the following three operations in sequence:
    1. Scale the tensor.
    2. Apply upper triangular mask (typically used in gpt models).
    3. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs, scale):
        import scaled_upper_triang_masked_softmax_cuda

        scale_t = torch.tensor([scale])
        softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(inputs, scale_t[0])

        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        import scaled_upper_triang_masked_softmax_cuda

        softmax_results, scale_t = ctx.saved_tensors
        input_grads = scaled_upper_triang_masked_softmax_cuda.backward(
            output_grads, softmax_results, scale_t[0]
        )

        return input_grads, None


class ScaledMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs the following three operations in sequence:
    1. Scale the tensor.
    2. Apply the mask.
    3. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs, mask, scale):
        import scaled_masked_softmax_cuda

        scale_t = torch.tensor([scale])

        softmax_results = scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0])
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        import scaled_masked_softmax_cuda

        softmax_results, scale_t = ctx.saved_tensors

        input_grads = scaled_masked_softmax_cuda.backward(output_grads, softmax_results, scale_t[0])
        return input_grads, None, None


class ScaledSoftmax(torch.autograd.Function):
    """
    Fused operation which performs the following two operations in sequence:
    1. Scale the tensor.
    2. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs, scale):
        import scaled_softmax_cuda

        scale_t = torch.tensor([scale])

        softmax_results = scaled_softmax_cuda.forward(inputs, scale_t[0])
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        import scaled_softmax_cuda

        softmax_results, scale_t = ctx.saved_tensors

        input_grads = scaled_softmax_cuda.backward(output_grads, softmax_results, scale_t[0])
        return input_grads, None, None


class FusedScaleMaskSoftmax(nn.Module):
    """
    fused operation: scaling + mask + softmax

    Arguments:
        input_in_fp16: flag to indicate if input is in fp16 data format.
        input_in_bf16: flag to indicate if input is in bf16 data format.
        attn_mask_type: attention mask type (pad or causal)
        scaled_masked_softmax_fusion: flag to indicate the user wants to use softmax fusion
        mask_func: mask function to be applied.
        softmax_in_fp32: if true, softmax is performed at fp32 precision.
        scale: scaling factor used in input tensor scaling.
    """

    def __init__(
        self,
        input_in_fp16,
        input_in_bf16,
        attn_mask_type,
        scaled_masked_softmax_fusion,
        mask_func,
        softmax_in_fp32,
        scale,
    ):
        super(FusedScaleMaskSoftmax, self).__init__()
        self.input_in_fp16 = input_in_fp16
        self.input_in_bf16 = input_in_bf16
        assert not (
            self.input_in_fp16 and self.input_in_bf16
        ), "both fp16 and bf16 flags cannot be active at the same time."
        self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
        self.attn_mask_type = attn_mask_type
        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
        self.mask_func = mask_func
        self.softmax_in_fp32 = softmax_in_fp32
        self.scale = scale

        assert self.scale is None or softmax_in_fp32, "softmax should be in fp32 when scaled"

    def forward(self, input, mask):
        # [b, np, sq, sk]
        assert input.dim() == 4

        if self.is_kernel_available(mask, *input.size()):
            return self.forward_fused_softmax(input, mask)
        else:
            return self.forward_torch_softmax(input, mask)

    def is_kernel_available(self, mask, b, np, sq, sk):
        attn_batches = b * np

        if (
            self.scaled_masked_softmax_fusion  # user wants to fuse
            and self.input_in_float16  # input must be fp16
            and 16 < sk <= 4096  # sk must be 16 ~ 4096
            and sq % 4 == 0  # sq must be a multiple of 4
            and sk % 4 == 0  # sk must be a multiple of 4
            and attn_batches % 4 == 0  # np * b must be a multiple of 4
        ):
            if 0 <= sk <= 4096:
                batch_per_block = self.get_batch_per_block(sq, sk, b, np)

                if self.attn_mask_type == AttnMaskType.causal:
                    if attn_batches % batch_per_block == 0:
                        return True
                else:
                    if sq % batch_per_block == 0:
                        return True
        return False

    def forward_fused_softmax(self, input, mask):
        b, np, sq, sk = input.size()
        scale = self.scale if self.scale is not None else 1.0

        if self.attn_mask_type == AttnMaskType.causal:
            assert sq == sk, "causal mask is only for self attention"

            # input is 3D tensor (attn_batches, sq, sk)
            input = input.view(-1, sq, sk)
            probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
            return probs.view(b, np, sq, sk)
        else:
            # input is 4D tensor (b, np, sq, sk)
            if mask is not None:
                return ScaledMaskedSoftmax.apply(input, mask, scale)
            else:
                return ScaledSoftmax.apply(input, scale)

    def forward_torch_softmax(self, input, mask):
        if self.input_in_float16 and self.softmax_in_fp32:
            input = input.float()

        if self.scale is not None:
            input = input * self.scale
        mask_output = self.mask_func(input, mask) if mask is not None else input
        probs = torch.nn.Softmax(dim=-1)(mask_output)

        if self.input_in_float16 and self.softmax_in_fp32:
            if self.input_in_fp16:
                probs = probs.half()
            else:
                probs = probs.bfloat16()

        return probs

    @staticmethod
    def get_batch_per_block(sq, sk, b, np):
        import scaled_masked_softmax_cuda

        return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np)
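
A construction sketch, not part of this commit. The attention_mask_func below is an assumption for illustration (a mask function that fills masked positions with a large negative value); only the constructor runs here, since the fused path additionally needs the compiled CUDA extensions.

import torch

from megatron.core.fusions.fused_softmax import FusedScaleMaskSoftmax
from megatron.core.transformer.enums import AttnMaskType


def attention_mask_func(attention_scores, attention_mask):
    # Assumed helper: suppress positions where attention_mask is True.
    return attention_scores.masked_fill(attention_mask, -10000.0)


softmax = FusedScaleMaskSoftmax(
    input_in_fp16=True,
    input_in_bf16=False,
    attn_mask_type=AttnMaskType.causal,
    scaled_masked_softmax_fusion=True,
    mask_func=attention_mask_func,
    softmax_in_fp32=True,
    scale=None,
)

# is_kernel_available gates on the shape constraints listed above
# (16 < sk <= 4096, sq % 4 == 0, sk % 4 == 0, b * np % 4 == 0);
# otherwise forward() falls back to forward_torch_softmax.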
megatron/core/inference_params.py  0 → 100644

class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference."""

    def __init__(self, max_batch_size, max_sequence_length):
        self.max_sequence_length = max_sequence_length
        self.max_batch_size = max_batch_size
        self.sequence_len_offset = 0
        self.batch_size_offset = 0
        self.key_value_memory_dict = {}

    def swap_key_value_dict(self, batch_idx):
        "swap between batches"
        if len(self.key_value_memory_dict) == 0:
            raise ValueError("should not swap when dict is empty")

        for layer_number in self.key_value_memory_dict.keys():
            inference_key_memory, inference_value_memory = self.key_value_memory_dict[layer_number]
            assert (
                len(batch_idx) == inference_key_memory.shape[1]
            )  # make sure batch size is the same
            new_inference_key_memory = inference_key_memory[:, batch_idx]
            new_inference_value_memory = inference_value_memory[:, batch_idx]
            self.key_value_memory_dict[layer_number] = (
                new_inference_key_memory,
                new_inference_value_memory,
            )
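
A minimal sketch, not part of this commit, of the incremental-decoding bookkeeping InferenceParams supports: attention layers cache key/value tensors per layer number, swap_key_value_dict reorders the cached batch, and sequence_len_offset advances as tokens are generated. The cache shapes are assumptions consistent with the shape[1] batch check above.

import torch

from megatron.core.inference_params import InferenceParams

params = InferenceParams(max_batch_size=4, max_sequence_length=128)

# An attention layer would allocate its cache once, keyed by layer number;
# a [max_seq_len, max_batch, heads, head_dim] layout keeps batch in dim 1,
# which is what swap_key_value_dict's shape[1] check expects.
params.key_value_memory_dict[1] = (
    torch.zeros(128, 4, 8, 64),
    torch.zeros(128, 4, 8, 64),
)

params.swap_key_value_dict(batch_idx=[2, 0, 1, 3])   # reorder the cached batch
params.sequence_len_offset += 1                      # advance after generating a token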
megatron/core/model_parallel_config.py  0 → 100644

# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

from dataclasses import dataclass
from typing import Callable, Optional

import torch


@dataclass
class ModelParallelConfig:
    """Base configuration for Megatron Core

    Model Parallelism
    -----------------

    tensor_model_parallel_size (int): Intra-layer model parallelism. Splits tensors across GPU ranks. Defaults to 1.

    pipeline_model_parallel_size (int): Inter-layer model parallelism. Splits transformer layers across GPU
        ranks. Defaults to 1.

    virtual_pipeline_model_parallel_size (int): Interleaved pipeline parallelism is used to improve performance by
        reducing the pipeline bubble. Considers a transformer block as a list of smaller transformer (virtual) blocks.
        The number of virtual blocks per pipeline model parallel rank is the virtual model parallel size. See Efficient
        Large-Scale Language Model Training on GPU Clusters Using Megatron-LM: https://arxiv.org/pdf/2104.04473.pdf for
        more details. Defaults to None.

    sequence_parallel (bool): Makes tensor parallelism more memory efficient for LLMs (20B+) by
        parallelizing layer norms and dropout sequentially. See Reducing Activation Recomputation in Large Transformer
        Models: https://arxiv.org/abs/2205.05198 for more details. Defaults to False.

    Initialization
    --------------

    perform_initialization (bool, default=True): If true, weights are initialized. This option can be useful when you
        know you are going to load values from a checkpoint.

    use_cpu_initialization: (bool, default=False): When set to False, we initialize the weights directly on the GPU.
        Transferring weights from CPU to GPU can take a significant amount of time for large models. Defaults to False.

    Training
    --------

    fp16 (bool): If true, train with fp16 mixed precision training. Defaults to False.

    bf16 (bool): If true, train with bf16 mixed precision training. Defaults to False.

    params_dtype (torch.dtype): dtype used when initializing the weights. Defaults to torch.float32

    timers (optional, default=None): TODO

    Optimizations
    -------------

    gradient_accumulation_fusion (bool): If true, fuses weight gradient accumulation to GEMMs. Requires the custom CUDA
        extension fused_weight_gradient_mlp_cuda module. To use gradient_accumulation_fusion you must install APEX with
        --cpp_ext and --cuda_ext. For example: "pip install --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" ".
        Note that the extension requires CUDA>=11. Otherwise, you must turn off gradient accumulation fusion.
        Defaults to False.

    async_tensor_model_parallel_allreduce (bool, default=True): If true, enables asynchronous execution of
        tensor-model-parallel all-reduce with weight gradient computation of a column-linear layer. Defaults to False.

    Pipeline Parallelism
    --------------------

    pipeline_dtype (required): dtype used in p2p communication, usually params_dtype

    grad_scale_func (optional, default=None): If using loss scaling, this function should take the loss and return the
        scaled loss. If None, no function is called on the loss.

    enable_autocast (bool): If true runs the forward step function inside torch.autocast context. Default is False.

    autocast_dtype (torch.dtype): dtype to pass to torch.amp.autocast when enabled. Default is pipeline_dtype.

    variable_seq_lengths (bool, default=False): Support for variable sequence lengths across microbatches. Setting this
        communicates the size of tensors during pipeline parallelism communication; because of this extra overhead it
        should only be set if the sequence length varies by microbatch within a global batch.

    num_microbatches_with_partial_activation_checkpoints (int, default=None): If int, set the number of microbatches
        where not all of the layers will be checkpointed and recomputed. The rest of the microbatches within the window
        of maximum outstanding microbatches will recompute all layers (either full recompute or selective recompute). If
        None, the checkpoint and recompute will be left up to the forward_step function.

    overlap_p2p_comm (bool, optional, default=False): When True some of the peer to peer communication for pipeline
        parallelism will overlap with computation. Must be False if batch_p2p_comm is true.

    batch_p2p_comm (bool, default=True): Use batch_isend_irecv instead of individual isend/irecv calls. Must be False
        if overlap_p2p_comm is True.

    batch_p2p_sync (bool, default=True): When using batch_isend_irecv, do a cuda.device.synchronize afterward to work
        around a bug in older version of PyTorch.

    use_ring_exchange_p2p (bool, default=False): Use custom ring_exchange kernel instead of
        torch.distributed.batch_isend_irecv(). Requires custom built torch with torch.distributed.ring_exchange.

    deallocate_pipeline_outputs (optional, default=False): If True, output data is deallocated after the tensor is sent
        to the next pipeline stage. Helps with saving memory, does nothing when pipeline parallel is not used.

    no_sync_func (optional): Function that creates a context that suppresses asynchronous data-parallel
        communication. If the model is an instance of torch.nn.DistributedDataParallel, the default is to use
        torch.nn.DistributedDataParallel.no_sync.

    grad_sync_func (optional): Function that launches asynchronous gradient reductions (e.g. distributed optimizer
        gradient reduce-scatters). The function should take one argument: an iterable of parameters whose gradients are
        to be synchronized.

    param_sync_func (optional): Function that launches asynchronous parameter synchronizations (e.g. distributed
        optimizer parameter all-gathers). The function should take one argument: an iterable of parameters to be
        synchronized.

    """

    # Model parallelism
    tensor_model_parallel_size: int = 1
    pipeline_model_parallel_size: int = 1
    virtual_pipeline_model_parallel_size: Optional[int] = None
    sequence_parallel: bool = False

    # Initialization
    perform_initialization: bool = True
    use_cpu_initialization: bool = False

    # Training
    fp16: bool = False
    bf16: bool = False
    params_dtype: torch.dtype = torch.float32
    timers: Callable = None

    # Optimizations
    gradient_accumulation_fusion: bool = False
    async_tensor_model_parallel_allreduce: bool = False

    # Pipeline Parallel
    pipeline_dtype: torch.dtype = None
    grad_scale_func: Callable = None
    enable_autocast: bool = False
    autocast_dtype: torch.dtype = None
    variable_seq_lengths: bool = False
    num_microbatches_with_partial_activation_checkpoints: Optional[int] = None
    overlap_p2p_comm: bool = False
    batch_p2p_comm: bool = True
    batch_p2p_sync: bool = True
    use_ring_exchange_p2p: bool = False
    deallocate_pipeline_outputs: bool = False
    no_sync_func: Callable = None
    grad_sync_func: Callable = None
    param_sync_func: Callable = None

    def __post_init__(self):
        """ Python dataclass method that is used to modify attributes after initialization.
            See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more details.
        """
        if self.sequence_parallel:
            if self.tensor_model_parallel_size <= 1:
                raise ValueError("Can not use sequence parallelism without tensor parallelism")
            if self.async_tensor_model_parallel_allreduce:
                # sequence_parallelism already does this async
                self.async_tensor_model_parallel_allreduce = False

        if self.pipeline_model_parallel_size > 1:
            if self.pipeline_dtype is None:
                raise ValueError(
                    "When using pipeline parallelism, pipeline_dtype must be specified"
                )

        if self.autocast_dtype is None:
            self.autocast_dtype = self.params_dtype
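
A short sketch, not part of this commit, exercising the __post_init__ checks above: pipeline parallelism requires pipeline_dtype, sequence parallelism requires tensor parallelism, and autocast_dtype falls back to params_dtype.

import torch

from megatron.core.model_parallel_config import ModelParallelConfig

# Valid: pipeline_dtype is given, so autocast_dtype defaults to params_dtype.
config = ModelParallelConfig(
    tensor_model_parallel_size=2,
    pipeline_model_parallel_size=2,
    sequence_parallel=True,
    pipeline_dtype=torch.bfloat16,
    bf16=True,
)
assert config.autocast_dtype == config.params_dtype

# Invalid: sequence_parallel without tensor parallelism raises in __post_init__.
try:
    ModelParallelConfig(sequence_parallel=True)
except ValueError as e:
    print(e)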
megatron/core/models/__init__.py  0 → 100644

megatron/core/models/common/__init__.py  0 → 100644
megatron/model/rotary_pos_embedding.py → megatron/core/models/common/rotary_pos_embedding.py

 # coding=utf-8
 # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

 # The following code has been taken from https://github.com/NVIDIA/NeMo/blob/ \
 # 782b4e1652aaa43c8be390d9db0dc89544afa080/nemo/collections/nlp/modules/ \
 # common/megatron/rotary_pos_embedding.py

 import importlib.util

 import torch
 from torch import einsum, nn

 __all__ = ['RotaryEmbedding', 'apply_rotary_pos_emb']


 class RotaryEmbedding(nn.Module):
-    def __init__(self, dim):
+    def __init__(self, dim, seq_len_interpolation_factor=None):
         super().__init__()
+        self.seq_len_interpolation_factor = seq_len_interpolation_factor
         inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
-        self.register_buffer('inv_freq', inv_freq)
+        self.register_buffer('inv_freq', inv_freq, persistent=False)
+
+        if importlib.util.find_spec('einops') is None:
+            raise RuntimeError("einops is required for Rotary Embedding")

     def forward(self, max_seq_len, offset=0):
         seq = torch.arange(max_seq_len, device=self.inv_freq.device) + offset
+        if self.seq_len_interpolation_factor is not None:
+            seq = seq.type_as(self.inv_freq)
+            seq *= 1 / self.seq_len_interpolation_factor
         freqs = einsum('i , j -> i j', seq.type_as(self.inv_freq), self.inv_freq)
         # first part even vector components, second part odd vector components,
         #  2 * dim in dimension size
         emb = torch.cat((freqs, freqs), dim=-1)
         # emb [seq_length, .., dim]
-        return emb[:, None, None, :]
+        from einops import rearrange
+
+        return rearrange(emb, 'n d -> n 1 1 d')
+
+    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+        state_dict.pop(f'{prefix}inv_freq', None)
+        return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


 def _rotate_half(x):
     """
     change sign so the last dimension becomes [-odd, +even]
     """
-    x1, x2 = torch.chunk(x, 2, dim=-1)
+    from einops import rearrange
+
+    x = rearrange(x, '... (j d) -> ... j d', j=2)
+    x1, x2 = x.unbind(dim=-2)
     return torch.cat((-x2, x1), dim=-1)

 ...
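
A usage sketch of the updated RotaryEmbedding, not part of this commit; it requires einops, as the new constructor check enforces, and the interpolation factor value is arbitrary.

from megatron.core.models.common.rotary_pos_embedding import RotaryEmbedding

rotary = RotaryEmbedding(dim=64, seq_len_interpolation_factor=2)
emb = rotary(max_seq_len=16)
# rearrange(emb, 'n d -> n 1 1 d') yields [seq, 1, 1, dim], ready to broadcast
# against query/key tensors; with the interpolation factor the positions are
# compressed by 1/2 before the frequencies are computed.
assert emb.shape == (16, 1, 1, 64)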
megatron/core/models/gpt/__init__.py  0 → 100644

from .gpt_model import GPTModel
megatron/core/models/gpt/gpt_embedding.py  0 → 100644

# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import torch

from megatron.core import tensor_parallel
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import (
    make_sharded_tensor_for_checkpoint,
    make_tp_sharded_tensor_for_checkpoint,
)


class GPTEmbedding(MegatronModule):
    """Language model embeddings.

    Arguments:
        config (TransformerConfig): config object with all necessary configs for TransformerBlock
        vocab_size (int): vocabulary size
        max_sequence_length (int): maximum size of sequence. This
                                   is used for positional embedding
        add_position_embedding (bool): Add a position embedding.
        embedding_dropout_prob (float): dropout probability for embeddings
    """

    def __init__(
        self,
        config: TransformerConfig,
        vocab_size: int,
        max_sequence_length: int,
        add_position_embedding: bool,
    ):
        super().__init__(config=config)

        self.config: TransformerConfig = config
        self.vocab_size: int = vocab_size
        self.max_sequence_length: int = max_sequence_length
        self.add_position_embedding: bool = add_position_embedding

        # Word embeddings (parallel).
        self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
            num_embeddings=self.vocab_size,
            embedding_dim=self.config.hidden_size,
            init_method=self.config.init_method,
            config=self.config,
        )

        # Position embedding (serial).
        if self.add_position_embedding:
            self.position_embeddings = torch.nn.Embedding(
                self.max_sequence_length, self.config.hidden_size
            )

            # Initialize the position embeddings.
            if self.config.perform_initialization:
                self.config.init_method(self.position_embeddings.weight)

        # Embeddings dropout
        self.embedding_dropout = torch.nn.Dropout(self.config.hidden_dropout)

    def zero_parameters(self):
        """Zero out all parameters in embedding."""
        self.word_embeddings.weight.data.fill_(0)
        self.word_embeddings.weight.shared = True
        self.position_embeddings.weight.data.fill_(0)
        self.position_embeddings.weight.shared = True

    def forward(self, input_ids, position_ids):
        # Embeddings.
        word_embeddings = self.word_embeddings(input_ids)
        if self.add_position_embedding:
            position_embeddings = self.position_embeddings(position_ids)
            embeddings = word_embeddings + position_embeddings
        else:
            embeddings = word_embeddings

        # Data format change to avoid explicit transposes : [b s h] --> [s b h].
        embeddings = embeddings.transpose(0, 1).contiguous()

        # If the input flag for fp32 residual connection is set, convert for float.
        if self.config.fp32_residual_connection:
            embeddings = embeddings.float()

        # Dropout.
        if self.config.sequence_parallel:
            embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
            with tensor_parallel.get_cuda_rng_tracker().fork():
                embeddings = self.embedding_dropout(embeddings)
        else:
            embeddings = self.embedding_dropout(embeddings)

        return embeddings

    def sharded_state_dict(self, prefix=''):
        sharded_state_dict = {}

        word_embeddings_prefix = f'{prefix}word_embeddings.'
        word_embeddings_state_dict = self.word_embeddings.state_dict(
            prefix=word_embeddings_prefix, keep_vars=True
        )

        sharded_word_embeddings_key = f'{word_embeddings_prefix}weight'
        sharded_word_embeddings_tensor = make_tp_sharded_tensor_for_checkpoint(
            tensor=word_embeddings_state_dict[sharded_word_embeddings_key],
            key=sharded_word_embeddings_key,
            allow_shape_mismatch=True,
        )
        sharded_state_dict[sharded_word_embeddings_key] = sharded_word_embeddings_tensor

        if self.add_position_embedding:
            position_embeddings_prefix = f'{prefix}position_embeddings.'
            position_embeddings_state_dict = self.position_embeddings.state_dict(
                prefix=position_embeddings_prefix, keep_vars=True
            )
            sharded_position_embeddings_key = f'{position_embeddings_prefix}weight'
            sharded_position_embeddings_tensor = make_sharded_tensor_for_checkpoint(
                tensor=position_embeddings_state_dict[sharded_position_embeddings_key],
                key=sharded_position_embeddings_key,
            )
            sharded_state_dict[sharded_position_embeddings_key] = sharded_position_embeddings_tensor

        return sharded_state_dict
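
A hedged sketch, not part of this commit, of the forward() contract: it assumes the tensor-model-parallel state needed by VocabParallelEmbedding has already been initialized elsewhere, and the vocabulary size and context length are placeholders.

import torch

from megatron.core.models.gpt.gpt_embedding import GPTEmbedding


def embed_batch(config, input_ids: torch.Tensor, position_ids: torch.Tensor):
    embedding = GPTEmbedding(
        config=config,
        vocab_size=32000,           # placeholder vocabulary size
        max_sequence_length=2048,   # placeholder context length
        add_position_embedding=True,
    )
    # input_ids / position_ids are [b, s]; the output is transposed to [s, b, h]
    # as noted in forward() above.
    return embedding(input_ids, position_ids)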
megatron/core/models/gpt/gpt_model.py  0 → 100644

This diff is collapsed.