wangsen / megatron-LM-llama · Commits

Commit 523ec9cc, authored Sep 03, 2024 by wangsen

Commit message: all

Pipeline #1668 failed with stages in 0 seconds. Changes: 410 · Pipelines: 1
Showing 20 changed files with 3618 additions and 0 deletions (+3618 / -0).
megatron/core/dist_checkpointing/strategies/__pycache__/fully_parallel.cpython-310.pyc (+0 / -0)
megatron/core/dist_checkpointing/strategies/async_utils.py (+214 / -0)
megatron/core/dist_checkpointing/strategies/base.py (+148 / -0)
megatron/core/dist_checkpointing/strategies/filesystem_async.py (+288 / -0)
megatron/core/dist_checkpointing/strategies/fully_parallel.py (+826 / -0)
megatron/core/dist_checkpointing/strategies/state_dict_saver.py (+134 / -0)
megatron/core/dist_checkpointing/strategies/tensorstore.py (+131 / -0)
megatron/core/dist_checkpointing/strategies/torch.py (+652 / -0)
megatron/core/dist_checkpointing/strategies/two_stage.py (+257 / -0)
megatron/core/dist_checkpointing/strategies/zarr.py (+300 / -0)
megatron/core/dist_checkpointing/utils.py (+154 / -0)
megatron/core/distributed/__init__.py (+6 / -0)
megatron/core/distributed/__pycache__/__init__.cpython-310.pyc (+0 / -0)
megatron/core/distributed/__pycache__/distributed_data_parallel.cpython-310.pyc (+0 / -0)
megatron/core/distributed/__pycache__/distributed_data_parallel_config.cpython-310.pyc (+0 / -0)
megatron/core/distributed/__pycache__/finalize_model_grads.cpython-310.pyc (+0 / -0)
megatron/core/distributed/__pycache__/param_and_grad_buffer.cpython-310.pyc (+0 / -0)
megatron/core/distributed/distributed_data_parallel.py (+325 / -0)
megatron/core/distributed/distributed_data_parallel_config.py (+32 / -0)
megatron/core/distributed/finalize_model_grads.py (+151 / -0)
Too many changes to show: to preserve performance only 410 of 410+ files are displayed.
megatron/core/dist_checkpointing/strategies/__pycache__/fully_parallel.cpython-310.pyc (new file, mode 100644)
File added
megatron/core/dist_checkpointing/strategies/async_utils.py (new file, mode 100644)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

"""
This module provides async utilities which allow starting
a checkpoint save process in the background.
"""
import logging
from collections import deque
from time import time
from typing import Callable, List, NamedTuple, Optional, Tuple

import torch
from torch import multiprocessing as mp

logger = logging.getLogger(__name__)


class AsyncRequest(NamedTuple):
    """ Represents an async request that needs to be scheduled for execution.

    Args:
        async_fn (Callable, optional): async function to call. None represents noop.
        async_fn_args (Tuple): args to pass to `async_fn`.
        finalize_fns (List[Callable]): list of functions to call to finalize the request.
            These functions will be called synchronously after `async_fn` is done
            *on all ranks*.
    """

    async_fn: Optional[Callable]
    async_fn_args: Tuple
    finalize_fns: List[Callable]
    is_frozen: bool = False

    def add_finalize_fn(self, fn: Callable) -> None:
        """ Adds a new finalize function to the request.

        Args:
            fn (Callable): function to add to the async request. This function
                will be called *after* existing finalization functions.

        Returns:
            None
        """
        if self.is_frozen:
            raise RuntimeError('Cannot add finalization functions to a frozen AsyncRequest')
        self.finalize_fns.append(fn)

    def execute_sync(self) -> None:
        """ Helper to synchronously execute the request.

        This logic is equivalent to what should happen in case of the async call.
        """
        if self.async_fn is not None:
            self.async_fn(*self.async_fn_args)
        torch.distributed.barrier()
        for finalize_fn in self.finalize_fns:
            finalize_fn()

    def freeze(self) -> 'AsyncRequest':
        """ Freezes the async request, disallowing adding new finalization functions.

        Returns:
            AsyncRequest: new async request with all same fields except for the
                `is_frozen` flag.
        """
        return self._replace(is_frozen=True)


class DistributedAsyncCaller:
    """ Wrapper around mp.Process that ensures correct semantic of distributed finalization.

    Starts process asynchronously and allows checking if all processes on all ranks are done.
    """

    def __init__(self):
        self.process: Optional[mp.Process] = None
        self.start_time: Optional[float] = None

    def schedule_async_call(self, async_fn: Optional[Callable], save_args: Tuple,) -> None:
        """ Spawn a process with `async_fn` as the target.

        This method must be called on all ranks.

        Args:
            async_fn (Callable, optional): async function to call. If None,
                no process will be started.
            save_args (Tuple): async function args.
        """
        if async_fn is None:
            return  # nothing to do
        torch.cuda.synchronize()
        ctx = mp.get_context('fork')
        self.start_time = time()
        self.process = ctx.Process(target=async_fn, args=save_args,)
        self.process.start()

    def is_current_async_call_done(self, blocking=False) -> bool:
        """ Check if async save is finished on all ranks.

        For semantic correctness, requires rank synchronization in each check.
        This method must be called on all ranks.

        Args:
            blocking (bool, optional): if True, will wait until the call is done
                on all ranks. Otherwise, returns immediately if at least one rank
                is still active. Defaults to False.

        Returns:
            bool: True if all ranks are done (immediately or after active wait
                if `blocking` is True), False if at least one rank is still active.
        """
        # The following takes the same overhead as torch.distributed.barrier (single integer all-reduce)
        is_alive = int(self.process.is_alive()) if self.process is not None else 0
        ten = torch.tensor([is_alive], dtype=torch.int, device=torch.cuda.current_device())
        logger.debug(f"rank: {torch.distributed.get_rank()}, DistributedAsyncCaller is_alive: {is_alive}")
        torch.distributed.all_reduce(ten)
        if ten[0] > 0 and not blocking:
            return False
        else:
            if self.process is not None:
                logger.debug(f"rank: {torch.distributed.get_rank()}, joining self.process")
                self.process.join()
                self.process = None

                logger.debug(
                    f"DistributedAsyncCaller: Async process join finished after {time() - self.start_time:.2f}s from forking"
                )
                self.start_time = None
            return True


class _ActiveAsyncRequest(NamedTuple):
    """ Helper to represent an active async call.

    Args:
        idx (int): index of the call (starting from 0)
        async_caller (DistributedAsyncCaller): async caller instance that represents
            the async process handling the async request
        async_request (AsyncRequest): async request that is being called
    """

    idx: int
    async_caller: DistributedAsyncCaller
    async_request: AsyncRequest


class AsyncCallsQueue:
    """ Manages a queue of async calls.

    Allows adding a new async call with `schedule_async_request` and finalizing
    active calls with `maybe_finalize_async_calls`.
    """

    def __init__(self):
        self.async_calls: deque[_ActiveAsyncRequest] = deque([])
        self.call_idx: int = -1

    def schedule_async_request(self, async_request: AsyncRequest) -> int:
        """ Start a new async call and add it to a queue of active async calls.

        This method must be called on all ranks.

        Args:
            async_request (AsyncRequest): async request to start.

        Returns:
            int: index of the async call that was started.
                This can help the user keep track of the async calls.
        """
        self.call_idx += 1
        async_caller = DistributedAsyncCaller()
        async_request = async_request.freeze()
        async_caller.schedule_async_call(async_request.async_fn, async_request.async_fn_args)
        self.async_calls.append(_ActiveAsyncRequest(self.call_idx, async_caller, async_request))
        return self.call_idx

    def maybe_finalize_async_calls(self, blocking=False) -> List[int]:
        """ Finalizes all available calls.

        This method must be called on all ranks.

        Args:
            blocking (bool, optional): if True, will wait until all active requests
                are done. Otherwise, finalizes only the async request that already
                finished. Defaults to False.

        Returns:
            List[int]: list of indices (as returned by `schedule_async_request`)
                of async calls that have been successfully finalized.
        """
        call_idx_finalized = []
        while self.async_calls:
            next_async_done = self.async_calls[0].async_caller.is_current_async_call_done(blocking)
            if not next_async_done:
                break
            call_idx, _, async_request = self.async_calls.popleft()
            for finalize_fn in async_request.finalize_fns:
                finalize_fn()
            ten = torch.tensor([call_idx], dtype=torch.int, device=torch.cuda.current_device())
            torch.distributed.all_reduce(ten, op=torch.distributed.ReduceOp.MAX)
            assert (
                ten.item() == call_idx
            ), 'Unmatched async calls. That probably means not all ranks are participating in async finalization'
            call_idx_finalized.append(call_idx)
        return call_idx_finalized

    def get_num_unfinalized_calls(self):
        """ Get the number of active async calls. """
        return len(self.async_calls)

    def close(self):
        """ Finalize all calls upon closing. """
        self.maybe_finalize_async_calls(blocking=True)
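Taken together, AsyncRequest, DistributedAsyncCaller and AsyncCallsQueue are driven entirely by the caller. The following is a minimal usage sketch, not part of this commit; it assumes torch.distributed is already initialized and a CUDA device is set (the queue all-reduces liveness flags on GPU), and the save/finalize callbacks are placeholders for real checkpoint callbacks.

# Minimal sketch of driving AsyncCallsQueue from a training loop (illustrative only).
from megatron.core.dist_checkpointing.strategies.async_utils import AsyncCallsQueue, AsyncRequest


def save_fn(path):
    # placeholder for the actual checkpoint-writing function run in the forked process
    print(f'writing checkpoint to {path}')


def finalize_fn():
    # called synchronously on all ranks once the background write is done everywhere
    print('checkpoint finalized')


queue = AsyncCallsQueue()
request = AsyncRequest(async_fn=save_fn, async_fn_args=('/tmp/ckpt',), finalize_fns=[finalize_fn])
queue.schedule_async_request(request)   # forks the writer process on every rank

# Inside the training loop, poll without blocking so training can proceed:
finished_ids = queue.maybe_finalize_async_calls(blocking=False)

# At shutdown, wait for all outstanding saves and run their finalizations:
queue.close()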
megatron/core/dist_checkpointing/strategies/base.py (new file, mode 100644)
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.

""" Strategies base interfaces. """

from abc import ABC, abstractmethod
from collections import defaultdict
from enum import Enum
from pathlib import Path

from ..mapping import CheckpointingException, ShardedStateDict, StateDict
from .async_utils import AsyncRequest


class StrategyAction(Enum):
    LOAD_COMMON = 'load_common'
    LOAD_SHARDED = 'load_sharded'
    SAVE_COMMON = 'save_common'
    SAVE_SHARDED = 'save_sharded'


default_strategies = defaultdict(dict)


def get_default_strategy(action: StrategyAction, backend: str, version: int):
    """ Retrieves a default strategy for a given action, backend and version. """
    try:
        if backend == 'zarr':
            error_hint = ' Please install `zarr` and `tensorstore<=0.1.45` packages'
            from .tensorstore import _import_trigger
            from .zarr import _import_trigger
        elif backend == 'torch_dist':
            error_hint = ' Please use PyTorch version >=2.1'
            from .torch import _import_trigger
    except ImportError as e:
        raise CheckpointingException(
            f'Cannot import a default strategy for: {(action.value, backend, version)}. '
            f'Error: {e}. Hint: {error_hint}'
        ) from e
    try:
        return default_strategies[action.value][(backend, version)]
    except KeyError as e:
        raise CheckpointingException(
            f'Cannot find a default strategy for: {(action.value, backend, version)}'
        ) from e


class LoadStrategyBase(ABC):
    """ Base class for a load strategy. Requires implementing checks for compatibility with a given checkpoint version. """

    @abstractmethod
    def check_backend_compatibility(self, loaded_version):
        raise NotImplementedError

    @abstractmethod
    def check_version_compatibility(self, loaded_version):
        raise NotImplementedError

    @property
    def can_handle_sharded_objects(self):
        """ Returns whether or not this strategy can handle loading ShardedObjects. """
        return False


class SaveStrategyBase(ABC):
    """ Base class for a save strategy. Requires defining a backend type and version of the saved format. """

    def __init__(self, backend: str, version: int):
        self.backend = backend
        self.version = version

    @property
    def can_handle_sharded_objects(self):
        """ Returns whether or not this strategy can handle saving ShardedObjects. """
        return False

    def __str__(self):
        return f'{self.__class__.__name__}({self.backend}, {self.version})'


class LoadCommonStrategy(LoadStrategyBase):
    """ Load strategy for common (non-sharded) objects """

    @abstractmethod
    def load(self, checkpoint_dir: Path):
        raise NotImplementedError


class LoadShardedStrategy(LoadStrategyBase):
    """ Load strategy for sharded tensors """

    @abstractmethod
    def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
        raise NotImplementedError

    @abstractmethod
    def load_tensors_metadata(self, checkpoint_dir: Path):
        """Load tensors metadata from the checkpoint.

        Returns a dictionary similar to a sharded state dict, but note that
        the dictionary keys are simply ShardedTensor keys (contrary to the
        actual sharded state dicts where keys correspond to state dict keys).

        Dict values are ShardedTensors without any sharding (so, the only useful
        information is tensors global shape and dtype).
        """
        raise NotImplementedError(
            f'{self.__class__.__name__} doesnt allow loading only sharded metadata'
        )


class SaveCommonStrategy(SaveStrategyBase):
    """ Save strategy for common (non-sharded) objects """

    @abstractmethod
    def save(self, common_state_dict: StateDict, checkpoint_dir: Path):
        raise NotImplementedError


class SaveShardedStrategy(SaveStrategyBase):
    """ Save strategy for sharded tensors """

    @abstractmethod
    def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
        raise NotImplementedError


class AsyncSaveShardedStrategy(SaveShardedStrategy):
    """ Save strategy suitable for async save. """

    @abstractmethod
    def async_save(
        self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path
    ) -> AsyncRequest:
        """ Perform preparation and return an AsyncRequest to the external caller.

        Args:
            sharded_state_dict (ShardedStateDict): sharded state dict to save
            checkpoint_dir (Path): checkpoint target directory

        Returns:
            AsyncRequest: represents the async save function and finalization function.
                It is the caller responsibility to actually schedule the async save.
        """
        raise NotImplementedError

    def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
        """ Each async strategy can be trivially used as a sync strategy. """
        async_request = self.async_save(sharded_state_dict, checkpoint_dir)
        async_request.execute_sync()
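To illustrate how these interfaces compose, here is a hypothetical minimal subclass (the class name and its no-op bodies are invented for this sketch and are not part of the commit): an async strategy only has to implement async_save returning an AsyncRequest, and the inherited save turns it into a synchronous save via AsyncRequest.execute_sync.

# Illustrative only: a trivial AsyncSaveShardedStrategy subclass.
from pathlib import Path

from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.dist_checkpointing.strategies.async_utils import AsyncRequest
from megatron.core.dist_checkpointing.strategies.base import AsyncSaveShardedStrategy


class NoOpAsyncSaveStrategy(AsyncSaveShardedStrategy):  # hypothetical example class
    def async_save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> AsyncRequest:
        def do_save():
            # a real strategy would write the sharded tensors to checkpoint_dir here
            print(f'would save {len(sharded_state_dict)} entries to {checkpoint_dir}')

        def finalize():
            print('save finalized on this rank')

        return AsyncRequest(async_fn=do_save, async_fn_args=(), finalize_fns=[finalize])


strategy = NoOpAsyncSaveStrategy(backend='torch_dist', version=1)
request = strategy.async_save({}, Path('/tmp/ckpt'))
# strategy.save(...) would call async_save() and then request.execute_sync(),
# which runs do_save(), a torch.distributed barrier, and the finalize functions
# (so it requires an initialized process group).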
megatron/core/dist_checkpointing/strategies/filesystem_async.py (new file, mode 100644)
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

""" Storage writer for PyT Distributed format allowing asynchronous save. """

import logging
import os
from itertools import chain
from pathlib import Path
from time import time
from typing import Callable, Dict, List, Optional, Tuple

import psutil
import torch
from torch import multiprocessing as mp
from torch.distributed.checkpoint import FileSystemWriter
from torch.distributed.checkpoint.filesystem import DEFAULT_SUFFIX, _StoragePrefix, _write_item
from torch.distributed.checkpoint.planner import SavePlan, SavePlanner, WriteItem, WriteItemType
from torch.distributed.checkpoint.storage import WriteResult
from torch.futures import Future

logger = logging.getLogger(__name__)

WriteBucket = Tuple[Path, str, Tuple[list, list]]  # represents writes to a single file


class FileSystemWriterAsync(FileSystemWriter):
    """
    Async-enabled implementation of FileSystemWriter using file IO.

    This class doesn't spawn the async process itself, relies on the external async mechanism.

    Flow:
    1. Call `write_data`
    2. Externally start async process with `get_save_function_and_args` function and args
    3. The async function to call is `writer_proxy_func` which calls
       `write_preloaded_data` in multiple processes

    After saving is finalized on all ranks:
    4. Call `super().finish` with the results gathered in `self.writer_result`

    Note that step (3) above can also be called synchronously.

    Currently, it's assumed that a separate writer is created for each ckpt save
    (intermediate state is stored as writer attributes).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if not self.single_file_per_rank:
            raise NotImplementedError('single_file_per_rank flag not supported for FileSystemWriterAsync')

        # Intermediate state between preparation and finalization
        self.write_buckets: Optional[List[WriteBucket]] = None
        self.write_results: Optional[Dict[int, List[WriteResult]]] = None

    def prepare_write_data(self, plan: SavePlan, planner: SavePlanner) -> None:
        """
        First stage of async saving. Copy data to CPU and plan the local saving.

        Args:
            plan (SavePlan): save plan generated by the PyT Distributed compatible planner
            planner (SavePlanner): save planner used to resolve the bytes and tensor data

        Returns: None, but stores the save plan in `self.write_buckets`
        """
        storage_plan: _StoragePrefix = plan.storage_data
        start = time()
        logger.debug(f"thread_count: {self.thread_count}, time: {start}")
        item_buckets = _split_by_size_and_type(self.thread_count, plan.items)
        logger.debug(f"bucket_prep, time: {time() - start}")

        start = time()
        # move tensors from GPU to CPU before starting async writing
        # We do D2H synchronously for now
        file_count = 0

        def gen_file():
            nonlocal file_count
            file_name = f"{storage_plan.prefix}{file_count}{DEFAULT_SUFFIX}"
            file_count += 1
            return file_name

        # Prepare bytes / tensor data in each bucket, which will be assigned to each writer process
        self.write_buckets = []
        for bucket in item_buckets:
            bytes_data = [
                (item, planner.resolve_data(item))
                for item in bucket
                if item.type == WriteItemType.BYTE_IO
            ]
            tensor_data = [
                (item, planner.resolve_data(item).detach().to("cpu", non_blocking=True))
                for item in bucket
                if item.type != WriteItemType.BYTE_IO
            ]
            if len(bytes_data) > 0 or len(tensor_data) > 0:
                file_name = gen_file()
                self.write_buckets.append(
                    (self.path / file_name, file_name, (bytes_data, tensor_data))
                )

        # Check if there is anything to write on this rank
        if len(self.write_buckets) > 0:
            assert len(self.write_buckets) <= self.thread_count, (
                len(self.write_buckets),
                self.thread_count,
            )
            ctx = mp.get_context('fork')
            self.write_results = ctx.Manager().dict()
        else:
            self.write_results = {}
        logger.debug(f"D2H and push, time: {time() - start}")

    def get_save_function_and_args(self) -> Tuple[Optional[Callable], Tuple]:
        """
        Get function that saves the data to storage along with its arguments.
        Allows the external caller to apply the save function synchronously or asynchronously.

        Returns: None (if there is nothing to write on this rank) or a tuple of:
            - the function that saves the data
            - arguments to that function
        """
        if not self.write_buckets:
            return None, ()
        return (self.write_preloaded_data_multiproc, (self.write_buckets, self.write_results))

    @staticmethod
    def write_preloaded_data_multiproc(
        write_buckets: List[WriteBucket], write_results: Dict[int, List[WriteResult]]
    ) -> None:
        """
        Performs saving data to storage with multiple processes.

        Args:
            write_buckets (List[WriteBucket]): write plan
            write_results (Dict[int, List[WriteResult]]): dict to store the write results to.
                Assumes multiprocessing save, so keys are local process indices

        Returns: None
        """
        w_start = time()
        ctx = mp.get_context('fork')
        p_list = [
            ctx.Process(
                target=FileSystemWriterAsync.write_preloaded_data,
                args=(i, write_bucket, write_results, True),
            )
            for i, write_bucket in enumerate(write_buckets)
        ]
        for p in p_list:
            p.start()
        for p in p_list:
            p.join()

        w_end = time()
        logger.debug(
            f"{w_end}, rank: {torch.distributed.get_rank()}, write(sync,parallel): {w_end - w_start}"
        )

    @staticmethod
    def write_preloaded_data(
        local_proc_idx: int,
        write_bucket: WriteBucket,
        write_results: Dict[int, List[WriteResult]],
        use_fsync: bool,
    ) -> None:
        """
        Performs actual data saving to storage.

        Args:
            local_proc_idx (int): index of a local process that performs writing
            write_bucket (WriteBucket): data to write to storage
            write_results (Dict[int, List[WriteResult]]): dict to store the write results to.
                Assumes multiprocessing save, so keys are local process indices
            use_fsync (bool): if True, calls os.fsync at the end of saving

        Returns: None, the write results are written to the `write_results` dict
        """
        mem_before = _process_memory()

        local_results = []
        file_name, storage_key, (bytes_data, tensor_data) = write_bucket
        with open(file_name, "wb") as stream:
            for write_item, data in bytes_data:
                local_results.append(_write_item(stream, data, write_item, storage_key))

            for write_item, tensor in tensor_data:
                assert tensor.is_cpu
                local_results.append(_write_item(stream, tensor, write_item, storage_key))

            if use_fsync:
                os.fsync(stream.fileno())
        write_results[local_proc_idx] = local_results

        mem_after = _process_memory()
        logger.debug(
            f"{local_proc_idx} consumed: {mem_after - mem_before}, before: {mem_before}, after: {mem_after}"
        )

    def write_data(self, plan: SavePlan, planner: SavePlanner,) -> Future[List[WriteResult]]:
        raise NotImplementedError('write_data not implemented for FileSystemWriterAsync')

    def retrieve_write_results(self) -> List[WriteResult]:
        """
        Turn self.write_results into a single results list. Includes error check.

        Returns (List[WriteResult]): the list of write results from all local processes performing the save.
        """
        assert self.write_results is not None
        assert self.write_buckets is not None
        if len(self.write_results) != len(self.write_buckets):
            raise RuntimeError(
                f'Incomplete worker results (expected {len(self.write_buckets)}, got {len(self.write_results)}.'
                f' This probably indicates a worker failure.'
            )
        return list(chain.from_iterable(self.write_results.values()))


def _split_by_size_and_type(bins: int, items: List[WriteItem]) -> List[List[WriteItem]]:
    """
    Splits write items according to item size into close to uniform bins.

    Same as torch.distributed.checkpoint.filesystem._split_by_size_and_type,
    but with a fixed _item_size function.

    Args:
        bins (int): numbers of bins to split to
        items (List[WriteItem]): list of write items

    Returns (List[List[WriteItem]]): write items split to bins
    """
    if bins == 1:
        return [items]

    bytes_items = [wi for wi in items if wi.type == WriteItemType.BYTE_IO]
    tensor_items = [wi for wi in items if wi.type != WriteItemType.BYTE_IO]

    buckets: List[List[WriteItem]] = [[] for _ in range(bins)]
    bucket_sizes = [0 for _ in range(bins)]

    tensor_items.sort(key=_item_size, reverse=True)

    # Assign bytes with a simple round-robin
    for i, item in enumerate(bytes_items):
        buckets[i % bins].append(item)

    # Then, assign tensors according to their sizes
    for item in tensor_items:
        # TODO replace with heapq
        idx = min(enumerate(bucket_sizes), key=lambda x: x[1])[0]
        buckets[idx].append(item)
        bucket_sizes[idx] += _item_size(item)

    return buckets


def _item_size(item: WriteItem) -> int:
    """
    Calculates size (in bytes) of a single write item.

    Same as torch.distributed.checkpoint.filesystem._item_size,
    but fixes computing chunk size (with item.tensor_data.chunk.sizes)

    Args:
        item (WriteItem): write item to compute the size of

    Returns (int): size of an item in bytes
    """
    size = 1
    assert item.tensor_data is not None
    # can't use math.prod as PT needs to support older python
    for s in item.tensor_data.chunk.sizes:
        size *= s

    dtype = item.tensor_data.properties.dtype
    return size * torch._utils._element_size(dtype)


def _process_memory() -> int:
    """
    Get memory used by current process.

    Returns (int): memory used by current process
    """
    process = psutil.Process(os.getpid())
    mem_info = process.memory_info()
    return mem_info.rss
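A hedged sketch of the flow described in the FileSystemWriterAsync docstring follows, combining the writer with the planning/finalization helpers from state_dict_saver.py later in this diff and the AsyncRequest machinery from async_utils.py. The helper name async_save_pyt_distributed and the constructor arguments are illustrative only, torch.distributed is assumed to be initialized, and the state dict is assumed to already be in a PyT Distributed compatible form.

# Illustrative wiring of FileSystemWriterAsync into an AsyncRequest (not part of this commit).
from megatron.core.dist_checkpointing.strategies.async_utils import AsyncRequest
from megatron.core.dist_checkpointing.strategies.filesystem_async import FileSystemWriterAsync
from megatron.core.dist_checkpointing.strategies.state_dict_saver import (
    save_state_dict_async_finalize,
    save_state_dict_async_plan,
)


def async_save_pyt_distributed(state_dict, checkpoint_dir, thread_count=2):
    # constructor arguments follow the PyTorch FileSystemWriter signature (assumption)
    writer = FileSystemWriterAsync(checkpoint_dir, thread_count=thread_count)

    # Stage 1 (synchronous): planning plus D2H copy; buckets are stored on the writer.
    writer, metadata, dist_wrapper = save_state_dict_async_plan(state_dict, writer)

    # Stage 2: the function to run in a background process (None if nothing to write).
    save_fn, save_args = writer.get_save_function_and_args()

    # Stage 3: finalization to run synchronously on all ranks after the background write.
    def finalize_fn():
        save_state_dict_async_finalize(writer, metadata, dist_wrapper)

    return AsyncRequest(save_fn, save_args, [finalize_fn])

# The returned AsyncRequest can then be scheduled with AsyncCallsQueue.schedule_async_request()
# or executed synchronously via AsyncRequest.execute_sync().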
megatron/core/dist_checkpointing/strategies/fully_parallel.py (new file, mode 100644)
import logging
from collections import defaultdict
from functools import reduce
from itertools import zip_longest
from pathlib import Path
from time import time
from typing import Dict, List, NamedTuple, Optional, Set, Tuple, TypeVar, cast

import numpy as np
import torch
import torch.distributed as dist

from megatron.core.dist_checkpointing import ShardedTensor
from megatron.core.dist_checkpointing.core import CheckpointingException
from megatron.core.dist_checkpointing.dict_utils import (
    dict_list_map_inplace,
    extract_matching_values,
    merge,
    nested_values,
)
from megatron.core.dist_checkpointing.mapping import ShardedStateDict, StateDict, is_main_replica
from megatron.core.dist_checkpointing.serialization import validate_sharding_integrity
from megatron.core.dist_checkpointing.strategies.base import (
    AsyncSaveShardedStrategy,
    LoadShardedStrategy,
    SaveShardedStrategy,
)

logger = logging.getLogger(__name__)


# _ShardId uniquely identifies a ShardedTensor. This is a subset of ShardedTensor
# attributes: key (str), global_offset (tuple) and flattened_range (optional tuple)
_ShardId = Tuple[str, tuple, Optional[tuple]]


class SaveLoadDistribution(NamedTuple):
    """ Represents a save or load distribution of ShardedTensors.

    Given distribution is valid only for a specific parallelization group,
    which is implicit here (not referenced by this class).

    Args:
        main_rank_for_shard (Dict[_ShardId, int]): specifies which rank should hold
            the main replica for a given shard
        shards_in_this_group (Set[_ShardId]): which shards have a main replica
            in this parallelization group
        shard_to_metadata (Dict[_ShardId, ShardedTensor]): maps ShardedTensor
            identifier to the original ShardedTensor
    """

    main_rank_for_shard: Dict[_ShardId, int]
    shards_in_this_group: Set[_ShardId]
    shard_to_metadata: Dict[_ShardId, ShardedTensor]


class FullyParallelSaveStrategyWrapper(AsyncSaveShardedStrategy):
    """ Wraps arbitrary strategy and distributes the save during `save`.

    The save distribution happens without any *data* communication.
    Only the *metadata* is exchanged and based on data replication on different
    ranks, we try to distribute the save as uniformly as possible.

    This wrapper assumes that setting `replica_id` to 0 will make the
    underlying strategy do the saving on current rank. All the other `replica_id`s
    are set to 1.

    Currently, the save distribution is realized with a greedy algorithm
    described in `distribute_shards_to_ranks`.

    Args:
        strategy (SaveShardedStrategy): base strategy to wrap
        parallelization_group (ProcessGroup, optional): process group to use for save
            distribution. Note that this doesn't have to match exactly the
            data distribution, but should cover the replication pattern
            to maximize performance. Defaults to the whole world.
        do_cache_distribution (bool, optional): whether to cache the save distribution
            from previous calls. Should be set to True only if the state dict
            structure between the calls is always the same. Defaults to False.
    """

    def __init__(
        self,
        strategy: SaveShardedStrategy,
        parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
        do_cache_distribution: bool = False,
    ):
        super().__init__(strategy.backend, strategy.version)
        self.base_strategy = strategy
        self.parallelization_group = parallelization_group
        self.do_cache_distribution = do_cache_distribution

        self.cached_distribution: Optional[SaveLoadDistribution] = None

    def async_save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
        if not isinstance(self.base_strategy, AsyncSaveShardedStrategy):
            raise CheckpointingException(
                f'Cannot apply async_save to non-async base strategy {self.base_strategy}'
            )
        self.apply_saving_parallelization(sharded_state_dict)
        return self.base_strategy.async_save(sharded_state_dict, checkpoint_dir)

    def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
        self.apply_saving_parallelization(sharded_state_dict)
        return self.base_strategy.save(sharded_state_dict, checkpoint_dir)

    def apply_saving_parallelization(self, sharded_state_dict: ShardedStateDict) -> None:
        """ Distributes the save across ranks by exchanging metadata.

        Exchanges metadata from the state dict and computes the uniform
        (as close as possible) distribution of saves among the ranks.

        If `self.do_cache_distribution` is True, caches the distribution between
        the calls and subsequent distributions happen without any inter-rank
        communication.

        Args:
            sharded_state_dict (ShardedStateDict): state dict to distribute the saving

        Returns: None
        """
        if self.do_cache_distribution and self.cached_distribution is not None:
            logger.debug(f'Apply *cached* save parallelization')
            precomputed_distribution = self.cached_distribution
        else:
            logger.debug(f'Apply save parallelization')
            precomputed_distribution = determine_main_replica_uniform_distribution(
                sharded_state_dict, self.parallelization_group
            )

        distribute_main_replicas_with_precomputed_distribution(
            sharded_state_dict, self.parallelization_group, precomputed_distribution
        )
        if self.cached_distribution is None:
            # First time applying the parallelization
            validate_sharding_integrity(nested_values(sharded_state_dict))
        if self.do_cache_distribution:
            self.cached_distribution = precomputed_distribution

    @property
    def can_handle_sharded_objects(self):
        return self.base_strategy.can_handle_sharded_objects


class FullyParallelLoadStrategyWrapper(LoadShardedStrategy):
    """ Wraps arbitrary load strategy and distributes the load during `load`.

    See `load` method docs for details.

    Args:
        strategy (LoadShardedStrategy): base strategy to wrap
        parallelization_group (ProcessGroup, optional): process group to use for load
            distribution. Note that this doesn't have to match exactly the
            data distribution, but should cover the replication pattern
            to maximize performance. Defaults to the whole world.
            In most cases, it's recommended to set it to the DP group.
        do_cache_distribution (bool, optional): whether to cache the load distribution
            from previous calls. Should be set to True only if the state dict
            structure between the calls is always the same. Defaults to False,
            since the loading in general happens only once during training.
            Note that the load distribution *cannot* be reused as a save distribution,
            because save/load is not fully symmetrical.
        exchange_algo (str): algorithm to use for exchanging the data.
            Options:
            - broadcast - each rank broadcasts individual tensors to others
            - gather_object - ranks all_gather_object the whole loaded state dicts
            - gather_rounds (default) - ranks all gather individual tensors in rounds
            See method docs for more details.
    """

    def __init__(
        self,
        strategy: LoadShardedStrategy,
        parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
        do_cache_distribution: bool = False,
        exchange_algo: str = 'gather_rounds',
    ):
        super().__init__()
        self.base_strategy = strategy
        self.parallelization_group = parallelization_group
        self.do_cache_distribution = do_cache_distribution
        self.exchange_algo = exchange_algo

        self.cached_distribution: Optional[SaveLoadDistribution] = None

    def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> StateDict:
        """ Distributes the load and calls underlying strategy only for parts of the state dict.

        Steps:
        1. Load metadata is exchanged between the ranks in the parallelization group.
        2. Each rank deterministically plans the load for the whole workload
           so that the loads are as uniform as possible.
        3. Each rank loads its planned shard of the checkpoint.
        4. All ranks exchange the loaded shards.

        Internode communication is involved in steps (1) (with metadata)
        and (4) (with actual data). Storage interaction is involved in step (3).

        Currently, the load distribution (step 2) is realized with a greedy algorithm
        described in `distribute_shards_to_ranks` (same as for saving distribution).

        Currently, the shards are all gathered between all ranks in the parallelization
        group. This might not be optimal (some ranks do not need all tensors),
        but it's a reasonable approximation for an optimal exchange in most scenarios.

        Args:
            sharded_state_dict (ShardedStateDict): sharded state dict to load
            checkpoint_dir (Path): checkpoint directory to load from

        Returns:
            StateDict: loaded state dict. The state dict should be equivalent to
                a state dict that would be loaded with the underlying strategy
                without this wrapper.
        """
        if torch.distributed.get_world_size(self.parallelization_group) <= 1:
            return self.base_strategy.load(sharded_state_dict, checkpoint_dir)

        # Step 1 and 2: exchange load metadata and distribute the load
        start = time()
        precomputed_distribution = self.apply_loading_parallelization(sharded_state_dict)
        assert (
            precomputed_distribution is not None
        ), 'Expecting non-trivial distribution for non-trivial parallelization group'
        end = time()
        logger.debug(f'self.apply_loading_parallelization took {end - start}s')
        start = end

        # Step 3: load part of the checkpoint.
        # Load only sharded objects first. ShardedTensors will be loaded separately
        # so that we can keep track of sharded tensors loaded by this rank
        (
            sharded_tensors,
            sharded_state_dict,
            to_load_shards,
            unloaded_shards,
        ) = self._defer_loading_sharded_tensors(sharded_state_dict)
        loaded_state_dict = self.base_strategy.load(sharded_state_dict, checkpoint_dir)

        end = time()
        logger.debug(f'Base load of ShardedObjects took {end - start}s')
        start = end

        # Load sharded tensors separately
        loaded_tensors = self.base_strategy.load(to_load_shards, checkpoint_dir)

        end = time()
        logger.debug(f'Base load of ShardedTensors took {end - start}s')
        start = end

        # Step 4: exchange data between ranks
        logger.debug(f'Applying parallel load with algo {self.exchange_algo}')
        if self.exchange_algo == 'gather_object':
            exchange_fn = self.exchange_loaded_tensors_gather_object
        elif self.exchange_algo == 'gather_rounds':
            exchange_fn = self.exchange_loaded_tensors_gather_rounds
        elif self.exchange_algo == 'broadcast':
            exchange_fn = self.exchange_loaded_tensors_broadcast
        else:
            raise NotImplementedError(f'Unrecognized gather algorithm: {self.exchange_algo}')

        all_loaded_tensors = exchange_fn(
            loaded_tensors, unloaded_shards, precomputed_distribution, self.parallelization_group,
        )
        if not set(unloaded_shards.keys()).issubset(all_loaded_tensors.keys()):
            missing_shards = set(unloaded_shards.keys()) - all_loaded_tensors.keys()
            raise CheckpointingException(
                f'Missing shards after fully parallel loading: {missing_shards}'
            )
        sync_start = time()
        torch.cuda.synchronize()
        end = time()
        logger.debug(f'torch.cuda.synchronize took {end - sync_start}s')
        logger.debug(f'self.exchange_loaded_tensors took {end - start}s')

        self.fill_in_deferred_sharded_tensors(sharded_tensors, all_loaded_tensors)
        merge(loaded_state_dict, sharded_tensors)
        return loaded_state_dict

    def _defer_loading_sharded_tensors(
        self, sharded_state_dict: ShardedStateDict
    ) -> Tuple[
        ShardedStateDict,
        ShardedStateDict,
        Dict[_ShardId, ShardedTensor],
        Dict[_ShardId, ShardedTensor],
    ]:
        """ Divides state dict into parts loaded by this vs other ranks.

        ShardedTensors with main replica_id will be loaded by this rank,
        others will be received by other ranks (after loading from storage).

        Args:
            sharded_state_dict (ShardedStateDict): state dict with ShardedTensor
                that will be divided.

        Returns: a tuple of:
            - ShardedStateDict: sub-state dict only with ShardedTensors
            - ShardedStateDict: sub-state dict with non-ShardedTensors
            - Dict[_ShardId, ShardedTensor]: ShardedTensor are uniquely identified
                by shard ids. This is a mapping from shard id to a corresponding
                ShardedTensor for tensors loaded by *this* rank
            - Dict[_ShardId, ShardedTensor]: mapping from shard id to a corresponding
                ShardedTensor for tensors loaded by *other* ranks
        """
        to_load_shards = {}
        unloaded_shards = {}

        sharded_tensors, sharded_state_dict = extract_matching_values(
            sharded_state_dict, lambda v: isinstance(v, ShardedTensor)
        )

        def wrap_non_main_replicas(x):
            if isinstance(x, ShardedTensor):
                # Assign shard to be loaded or not
                if is_main_replica(x.replica_id):
                    to_load_shards[_sharded_tensor_shard_id(x)] = x
                else:
                    unloaded_shards[_sharded_tensor_shard_id(x)] = x
            return x

        dict_list_map_inplace(wrap_non_main_replicas, sharded_tensors)
        return sharded_tensors, sharded_state_dict, to_load_shards, unloaded_shards

    def apply_loading_parallelization(
        self, sharded_state_dict: ShardedStateDict
    ) -> Optional[SaveLoadDistribution]:
        """ Distributes the load across ranks by exchanging metadata.

        Exchanges metadata from the state dict and computes the uniform
        (as close as possible) distribution of loads among the ranks.
        Marks ShardedTensors to be loaded by the current rank with replica_id 0
        (and others with non 0 values).

        If `self.do_cache_distribution` is True, caches the distribution between
        the calls and subsequent distributions happen without any inter-rank
        communication.

        Args:
            sharded_state_dict (ShardedStateDict): state dict to distribute the loading

        Returns:
            SaveLoadDistribution (optional): the computed loading distribution
        """
        if self.do_cache_distribution and self.cached_distribution is not None:
            logger.debug(f'Apply *cached* load parallelization')
            precomputed_distribution = self.cached_distribution
        else:
            logger.debug(f'Apply load parallelization')
            precomputed_distribution = determine_main_replica_uniform_distribution(
                sharded_state_dict, self.parallelization_group, True
            )

        distribute_main_replicas_with_precomputed_distribution(
            sharded_state_dict, self.parallelization_group, precomputed_distribution
        )
        if self.do_cache_distribution:
            self.cached_distribution = precomputed_distribution

        return precomputed_distribution

    def exchange_loaded_tensors_gather_object(
        self,
        loaded_tensors: Dict[_ShardId, torch.Tensor],
        unloaded_shards: Dict[_ShardId, ShardedTensor],
        precomputed_distribution: SaveLoadDistribution,
        parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
    ) -> Dict[_ShardId, torch.Tensor]:
        """ Exchange the tensors loaded by different ranks with a simple all_gather_object call.

        This version can be used for debugging purposes due to its simplistic
        implementation. Shouldn't be used if performance is important.

        Args:
            loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor
                shard ids to tensors already loaded by this rank.
            unloaded_shards (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor
                shard ids to ShardedTensors that aren't loaded yet.
            precomputed_distribution (SaveLoadDistribution): uniform load distribution
            parallelization_group (ProcessGroup, optional): process group used for load
                distribution. Tensors will be exchanged within this group

        Returns:
            Dict[_ShardId, torch.Tensor]: dictionary mapping shard ids to tensors
                needed by this rank to load a given state dict. Includes
                previously loaded tensors (from `loaded_tensors` input)
        """
        all_loaded_tensors_list = [None] * torch.distributed.get_world_size(
            group=parallelization_group
        )
        torch.distributed.all_gather_object(
            all_loaded_tensors_list, loaded_tensors, group=parallelization_group
        )
        all_loaded_tensors_list = cast(List[Dict[_ShardId, torch.Tensor]], all_loaded_tensors_list)
        all_loaded_tensors = reduce(lambda x, y: {**x, **y}, all_loaded_tensors_list)

        # Error checks
        if len(all_loaded_tensors) != sum(map(len, all_loaded_tensors_list)):
            err_msg = 'Duplicate shard ids loaded by different ranks'
            if torch.distributed.get_rank() == 0:
                logger.error(
                    f'{err_msg}. Shards ids by rank: {[lt.keys() for lt in all_loaded_tensors_list]}'
                )
            raise CheckpointingException(err_msg)

        return all_loaded_tensors

    @torch.no_grad()
    def exchange_loaded_tensors_gather_rounds(
        self,
        loaded_tensors: Dict[_ShardId, torch.Tensor],
        unloaded_shards: Dict[_ShardId, ShardedTensor],
        precomputed_distribution: SaveLoadDistribution = None,
        parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
    ) -> Dict[_ShardId, torch.Tensor]:
        """ Exchange the tensors loaded by different ranks with several all_gather calls.

        Groups tensors by dtype, divides tensors that will be exchanged into rounds
        and executes all_gather for tensors from each round.

        Note: the loading is distributed across ranks based on total loaded size
        in bytes, so there is no guarantee that number of rounds needed for each
        rank will be similar, which might result in a lot of almost empty
        all_gathers. The solution would be to group all tensors into a one
        bytes tensor and do a single all_gather (with similarly sized messages).

        Args:
            loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor
                shard ids to tensors already loaded by this rank.
            unloaded_shards (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor
                shard ids to ShardedTensors that aren't loaded yet.
            precomputed_distribution (SaveLoadDistribution): uniform load distribution
            parallelization_group (ProcessGroup, optional): process group used for load
                distribution. Tensors will be exchanged within this group

        Returns:
            Dict[_ShardId, torch.Tensor]: dictionary mapping shard ids to tensors
                needed by this rank to load a given state dict. Includes
                previously loaded tensors (from `loaded_tensors` input)
        """
        shard_to_saving_rank, _, shard_to_metadata = precomputed_distribution
        local_rank = torch.distributed.get_rank(group=self.parallelization_group)

        all_loaded_tensors = dict(loaded_tensors)

        # Group by dtype so that we all_gather tensors of the same dtype
        for dtype in sorted(set(map(lambda sh_ten: sh_ten.dtype, shard_to_metadata.values())), key=str):

            start = time()
            # shards_by_rank maps rank to tensors loaded by this rank
            shards_by_rank: List[List[torch.Tensor]] = [
                [] for _ in range(torch.distributed.get_world_size(group=parallelization_group))
            ]
            for shard_id, rank in shard_to_saving_rank.items():
                if shard_to_metadata[shard_id].dtype == dtype:
                    shards_by_rank[rank].append(shard_id)

            # Transpose `shards_by_rank` to form exchange rounds
            shards_by_round = zip_longest(*shards_by_rank, fillvalue=None)
            for round_idx, round_shard_ids in enumerate(shards_by_round):
                round_tensors = []
                for rank, shard_id in enumerate(round_shard_ids):
                    if shard_id is None:
                        # if no more useful data, the given rank will exchange empty tensor
                        local_ten = torch.empty(0, dtype=dtype, device='cuda')
                    else:
                        assert isinstance(shard_id, tuple), type(shard_id)
                        if rank == local_rank:
                            assert shard_id in all_loaded_tensors, (
                                shard_id,
                                all_loaded_tensors.keys(),
                            )
                            all_loaded_tensors[shard_id] = all_loaded_tensors[shard_id].cuda()
                            local_ten = all_loaded_tensors[shard_id]
                        else:
                            local_ten = self._get_empty_tensor_for_exchange(
                                shard_id, shard_to_metadata, unloaded_shards, all_loaded_tensors
                            )
                    round_tensors.append(local_ten)

                torch.distributed.all_gather(
                    list(round_tensors),
                    round_tensors[local_rank],
                    group=self.parallelization_group,
                    async_op=True,
                )

                del round_tensors  # remove tensor references

            end = time()
            if torch.distributed.get_rank() == 0:
                logger.debug(f'{dtype} exchange rounds all_gather schedule took {end - start}s')

        return all_loaded_tensors

    @torch.no_grad()
    def exchange_loaded_tensors_broadcast(
        self,
        loaded_tensors: Dict[_ShardId, torch.Tensor],
        unloaded_shards: Dict[_ShardId, ShardedTensor],
        precomputed_distribution: SaveLoadDistribution = None,
        parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
    ) -> Dict[_ShardId, torch.Tensor]:
        """ Exchange the tensors loaded by different ranks by a series of broadcasts.

        For each rank for each loaded tensor do a broadcast to the whole group.
        A reasonable tradeoff in terms of performance and simplicity.

        Args:
            loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor
                shard ids to tensors already loaded by this rank.
            unloaded_shards (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor
                shard ids to ShardedTensors that aren't loaded yet.
            precomputed_distribution (SaveLoadDistribution): uniform load distribution
            parallelization_group (ProcessGroup, optional): process group used for load
                distribution. Tensors will be exchanged within this group

        Returns:
            Dict[_ShardId, torch.Tensor]: dictionary mapping shard ids to tensors
                needed by this rank to load a given state dict. Includes
                previously loaded tensors (from `loaded_tensors` input)
        """
        shard_to_saving_rank, _, shard_to_metadata = precomputed_distribution
        local_rank = torch.distributed.get_rank(group=self.parallelization_group)

        all_loaded_tensors = dict(loaded_tensors)

        start = time()
        for shard_id, rank in shard_to_saving_rank.items():
            if rank == local_rank:
                assert shard_id in all_loaded_tensors, (shard_id, all_loaded_tensors.keys())
                all_loaded_tensors[shard_id] = all_loaded_tensors[shard_id].cuda()
                local_ten = all_loaded_tensors[shard_id]
            else:
                local_ten = self._get_empty_tensor_for_exchange(
                    shard_id, shard_to_metadata, unloaded_shards, all_loaded_tensors
                )

            global_src_rank = torch.distributed.get_global_rank(parallelization_group, rank)
            torch.distributed.broadcast(
                local_ten, src=global_src_rank, group=parallelization_group, async_op=True
            )

        end = time()
        if torch.distributed.get_rank() == 0:
            logger.debug(f'exchange broadcast schedule took {end - start}s')

        return all_loaded_tensors

    def _get_empty_tensor_for_exchange(
        self,
        shard_id: _ShardId,
        needed_shards: Dict[_ShardId, ShardedTensor],
        unneeded_shards: Dict[_ShardId, ShardedTensor],
        loaded_tensors: Dict[_ShardId, torch.Tensor],
    ) -> torch.Tensor:
        """ Determines the empty tensor to use for exchange.

        If shard_id is needed by this rank, it will be in the `unloaded_shards`.
        Otherwise, the metadata for this tensor can be found in `shard_to_metadata`

        Args:
            shard_id (_ShardId): shard_id that will be exchanged
            needed_shards (Dict[_ShardId, ShardedTensor]): mapping from shard ids
                to metadata for shards needed by this rank
            unneeded_shards (Dict[_ShardId, ShardedTensor]): mapping from shard ids
                to metadata for shards that can be discarded after exchange
            loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping where useful tensors
                are placed in

        Returns:
            torch.Tensor: empty tensor to be exchanged
        """
        local_unloaded_sh_ten = needed_shards.get(shard_id)
        if local_unloaded_sh_ten is None:
            sh_ten = unneeded_shards[shard_id]
            sh_ten.init_data('cuda')
            tensor = sh_ten.data
            sh_ten.data = None  # won't be used. free memory
        else:
            local_unloaded_sh_ten.init_data('cuda')
            tensor = local_unloaded_sh_ten.data
            loaded_tensors[shard_id] = tensor
        return tensor

    def fill_in_deferred_sharded_tensors(
        self, sharded_state_dict: ShardedStateDict, loaded_tensors: Dict[_ShardId, torch.Tensor]
    ) -> None:
        """ Fill in tensors not loaded by current rank with tensors from `loaded_tensors` map.

        Args:
            sharded_state_dict (ShardedStateDict): sharded state dict to fill in.
                ShardedTensors are completely replaced with corresponding torch.Tensors.
            loaded_tensors (Dict[_ShardId, torch.Tensor]): dict allowing to map
                ShardedTensor from the sharded_state_dict to loaded tensors.

        Returns: None
        """

        def fill_in_sharded_tensor(x):
            if isinstance(x, ShardedTensor):
                try:
                    x = loaded_tensors[_sharded_tensor_shard_id(x)]
                except KeyError as e:
                    raise CheckpointingException(
                        f'Missing loaded tensor shard: {_sharded_tensor_shard_id(x)}'
                    ) from e

            return x

        dict_list_map_inplace(fill_in_sharded_tensor, sharded_state_dict)

    @property
    def can_handle_sharded_objects(self):
        return self.base_strategy.can_handle_sharded_objects

    def load_tensors_metadata(self, checkpoint_dir: Path):
        self.base_strategy.load_tensors_metadata(checkpoint_dir)

    def check_backend_compatibility(self, loaded_version):
        self.base_strategy.check_backend_compatibility(loaded_version)

    def check_version_compatibility(self, loaded_version):
        self.base_strategy.check_version_compatibility(loaded_version)


def _sharded_tensor_shard_id(sharded_tensor: ShardedTensor) -> _ShardId:
    """ Unique id of the sharded tensor data.

    Should yield the same value for same data replicated on different ranks.

    Args:
        sharded_tensor (ShardedTensor): sharded tensor representing the data shard

    Returns (tuple): unique id of a data shard
    """
    f_range = sharded_tensor.flattened_range
    return (
        sharded_tensor.key,
        sharded_tensor.global_offset,
        None if f_range is None else (f_range.start, f_range.stop),
    )


def _shard_size(sh_ten: ShardedTensor):
    """ Returns size in bytes of a given sharded tensor. """
    if sh_ten.flattened_range is None:
        numel = np.product(sh_ten.local_shape)
    else:
        numel = sh_ten.flattened_range.stop - sh_ten.flattened_range.start
    return numel * torch._utils._element_size(sh_ten.dtype)


def determine_main_replica_uniform_distribution(
    sharded_state_dict: ShardedStateDict,
    parallelization_group: torch.distributed.ProcessGroup,
    is_loading: bool = False,
) -> Optional[SaveLoadDistribution]:
    """ Computes the save distribution.

    Should be used in conjunction with `distribute_main_replicas_with_precomputed_distribution`
    which applies the computed save distribution.

    We rely on the fact that the assignment algorithm is deterministic on all ranks,
    so there is no extra communication needed after metadata exchange.

    Args:
        sharded_state_dict (ShardedStateDict): state dict to compute the distribution of
        parallelization_group (ProcessGroup): distribution will be computed
            within this process group
        is_loading (bool, optional): whether the distribution is for loading or saving.
            For loading, even non-main replicas must be loaded by this parallelization
            group. Defaults to False.

    Returns (SaveLoadDistribution, optional): distribution that can be used to apply the
        parallelization. Returns None if the process_group is trivial (1 rank)
    """
    group_size = torch.distributed.get_world_size(group=parallelization_group)
    if group_size <= 1:
        return
    local_shards = list(
        sh_base
        for sh_base in nested_values(sharded_state_dict)
        if isinstance(sh_base, ShardedTensor)
    )
    local_shards_no_data = [ten.without_data() for ten in local_shards]

    all_shards = [None] * torch.distributed.get_world_size(group=parallelization_group)
    torch.distributed.all_gather_object(
        all_shards, local_shards_no_data, group=parallelization_group
    )

    shard_to_ranks = defaultdict(list)
    shard_to_size = {}
    shard_to_metadata = {}
    shards_saved_by_this_parallelization_group: Set[_ShardId] = set()
    for rank, rank_shards in enumerate(all_shards):
        for sh_ten in rank_shards:
            shard_id = _sharded_tensor_shard_id(sh_ten)
            shard_to_ranks[shard_id].append(rank)
            if shard_id not in shard_to_size:
                shard_to_size[shard_id] = _shard_size(sh_ten)
                shard_to_metadata[shard_id] = sh_ten
            if is_main_replica(sh_ten.replica_id) or is_loading:
                shards_saved_by_this_parallelization_group.add(shard_id)

    shard_to_ranks = {
        k: v for k, v in shard_to_ranks.items() if k in shards_saved_by_this_parallelization_group
    }

    shard_to_saving_rank = distribute_shards_to_ranks(
        shard_to_ranks, shard_to_size, len(all_shards)
    )

    return SaveLoadDistribution(
        shard_to_saving_rank, shards_saved_by_this_parallelization_group, shard_to_metadata
    )


def distribute_main_replicas_with_precomputed_distribution(
    sharded_state_dict: ShardedStateDict,
    parallelization_group: torch.distributed.ProcessGroup,
    precomputed_distribution: Optional[SaveLoadDistribution],
):
    """ Applies the save distribution computed with `determine_main_replica_uniform_distribution`.

    Based on rank assignment, sets replica ids of the shards saved by current rank to 0
    and all the other replica ids to 1.

    Args:
        sharded_state_dict (ShardedStateDict): state dict to apply the save distribution to
        parallelization_group (ProcessGroup): distribution will be applied within this
            process group. Must match with the process group passed to
            `determine_main_replica_uniform_distribution`.
        precomputed_distribution (SaveLoadDistribution): distribution computed with
            `determine_main_replica_uniform_distribution`

    Returns: None

    Example replica ids of tensors A, B, C before distribution:
        rank0: A: (0, 0, 0), B: (0, 0, 0), C: (0, 0, 0)
        rank1: A: (0, 0, 1), B: (0, 0, 1), C: (0, 0, 1)
        rank2: A: (0, 0, 2), B: (0, 0, 2), C: (0, 0, 2)

    Replicas after distribution for the example above:
        rank0: A: 0, B: 1, C: 1
        rank1: A: 1, B: 0, C: 1
        rank2: A: 1, B: 1, C: 0
    """
    if torch.distributed.get_world_size(group=parallelization_group) <= 1:
        return
    if precomputed_distribution is None:
        raise ValueError(
            'precomputed_distribution must be not None for non-trivial parallelization group'
        )

    local_shards = list(
        sh_base
        for sh_base in nested_values(sharded_state_dict)
        if isinstance(sh_base, ShardedTensor)
    )

    rank_within_dp_group = torch.distributed.get_rank(parallelization_group)
    for sh_ten in local_shards:
        shard_id = _sharded_tensor_shard_id(sh_ten)
        if (
            shard_id in precomputed_distribution.shards_in_this_group
            and rank_within_dp_group == precomputed_distribution.main_rank_for_shard[shard_id]
        ):
            sh_ten.replica_id = 0
        else:
            sh_ten.replica_id = 1


T = TypeVar('T')


def distribute_shards_to_ranks(
    shard_to_ranks: Dict[T, List[int]], shard_to_size: Dict[T, int], num_ranks: int
) -> Dict[T, int]:
    """ Computes uniform distribution of workload across ranks, based on sizes.

    Currently, the assignment is greedy, based on:
    1. Firstly, the coverage of each shard
       (how many ranks the shard is available on; lower coverage is assigned first)
    2. Secondly, the size of each shard (larger size is assigned first)
    3. Finally, shard id for differentiation.

    Third step is added because we rely on the fact that the assignment is deterministic on all ranks.

    Args:
        shard_to_ranks (Dict[T, List[int]]): mapping which tells which rank have access to which shards
        shard_to_size (Dict[T, int]): sizes of each shard
        num_ranks (int): number of ranks in the parallelization group

    Returns (Dict[T, int]): assignment of shard to rank (which rank should do the work
        to achieve maximal uniformity)
    """
    shard_to_ranks = {k: tuple(v) for k, v in shard_to_ranks.items()}
    shard_to_saving_rank = {}
    rank_sizes = [(0, rank) for rank in range(num_ranks)]

    # start from tensors with lowest coverage, then go by tensor size from largest (hence minus size)
    for shard_id, shard_ranks in sorted(
        shard_to_ranks.items(),
        key=lambda sh_id_ranks: (
            len(sh_id_ranks[1]),
            -shard_to_size[sh_id_ranks[0]],
            sh_id_ranks[0],
        ),
    ):
        # assign greedily to the least occupied rank
        size, rank = min((size, rank) for size, rank in rank_sizes if rank in shard_ranks)

        shard_to_saving_rank[shard_id] = rank
        rank_sizes[rank] = (size + shard_to_size[shard_id], rank)

    logger.debug(f'distribute_shards_to_ranks distribution: {rank_sizes}')

    return shard_to_saving_rank
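The greedy assignment in distribute_shards_to_ranks is a pure function, so it can be illustrated without any distributed setup. The shard ids and byte sizes below are invented purely for illustration.

# Toy illustration of the greedy assignment in distribute_shards_to_ranks.
# Shard 'c' is only available on rank 0, so it is assigned first; the larger
# of the remaining shards ('a') then goes to the emptier rank.
from megatron.core.dist_checkpointing.strategies.fully_parallel import distribute_shards_to_ranks

shard_to_ranks = {'a': [0, 1], 'b': [0, 1], 'c': [0]}   # which ranks hold each shard
shard_to_size = {'a': 40, 'b': 30, 'c': 10}             # shard sizes in bytes

assignment = distribute_shards_to_ranks(shard_to_ranks, shard_to_size, num_ranks=2)
print(assignment)  # {'c': 0, 'a': 1, 'b': 0} -> rank 0 writes 40 B total, rank 1 writes 40 B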
megatron/core/dist_checkpointing/strategies/state_dict_saver.py (new file, mode 100644)
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

""" State dict saver for PyT Distributed format allowing asynchronous save. """

from logging import getLogger
from time import time
from typing import TYPE_CHECKING, Optional, Tuple, cast

import torch
import torch.distributed as dist
from torch.distributed.checkpoint import CheckpointException
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE, Metadata
from torch.distributed.checkpoint.planner import SavePlanner
from torch.distributed.checkpoint.utils import _DistWrapper, _get_failure_dict

if TYPE_CHECKING:
    from .filesystem_async import FileSystemWriterAsync


logger = getLogger(__name__)


def save_state_dict_async_plan(
    state_dict: STATE_DICT_TYPE,
    storage_writer: 'FileSystemWriterAsync',
    process_group: Optional[dist.ProcessGroup] = None,
    coordinator_rank: int = 0,
    planner: Optional[SavePlanner] = None,
) -> Tuple['FileSystemWriterAsync', Metadata, _DistWrapper]:
    """
    First stage of saving a state dict to storage.

    This is an async adjustment of torch.distributed.checkpoint.state_dict_saver.
    In order to support async save, saving should be split into three parts:
    1. Planning
    2. Actual saving
    3. Finalization

    Out of these, step (2) *must* happen asynchronously.
    The first step is realized with this function.

    The planning part consists of several steps, described here:
    https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.SavePlanner

    Args:
        state_dict (STATE_DICT_TYPE): state dict to save
        storage_writer (FileSystemWriterAsync): in current version only an instance of
            FileSystemWriterAsync
        process_group (dist.ProcessGroup, optional): process group used for save planning
        coordinator_rank (int, optional): coordinator rank for planning. Defaults to 0.
        planner (SavePlanner, optional): save planner for torch.distributed.checkpoint format

    Returns: Tuple of:
        - storage writer (the one passed as input)
        - metadata from planning
        - distributed wrapper used for planning

    The return value of this function should be passed as an input to
    `save_state_dict_async_finalize`.
    """
    rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0

    dist_wrapper = _DistWrapper(process_group, True, coordinator_rank)
    if planner is None:
        planner = DefaultSavePlanner()
    assert planner is not None

    global_metadata = None

    def local_step():
        assert planner is not None
        planner.set_up_planner(state_dict, dist_wrapper.is_coordinator)
        storage_writer.set_up_storage_writer(dist_wrapper.is_coordinator)
        local_plan = planner.create_local_plan()
        local_plan = storage_writer.prepare_local_plan(local_plan)
        return local_plan

    def global_step(all_local_plans):
        nonlocal global_metadata

        assert planner is not None
        all_local_plans, global_metadata = planner.create_global_plan(all_local_plans)
        all_local_plans = storage_writer.prepare_global_plan(all_local_plans)
        return all_local_plans

    # Execute local and global planning
    start_plan = time()
    central_plan = dist_wrapper.reduce_scatter("plan", local_step, global_step)
    logger.debug(f"rank: {rank}, plan time: {time() - start_plan}")

    # Prepare async writing of tensors.
    # The `storage_writer` will store the information about tensors it needs to save
    start = time()
    final_local_plan = planner.finish_plan(central_plan)
    storage_writer.prepare_write_data(final_local_plan, planner)
    end = time()
    logger.debug(f"{time()} rank: {rank}, write(async) time: {end - start}")
    return storage_writer, cast(Metadata, global_metadata), dist_wrapper


def save_state_dict_async_finalize(
    storage_writer: 'FileSystemWriterAsync',
    global_metadata: Metadata,
    dist_wrapper: _DistWrapper,
) -> None:
    """
    Finalization of save_state_dict_async_plan.

    The input arguments are the same as the save_state_dict_async_plan output,
    the `write_results` are retrieved from the storage_writer.

    Args:
        storage_writer (FileSystemWriterAsync): storage writer used for planning
        global_metadata (Metadata): metadata created during planning
        dist_wrapper (_DistWrapper): distributed wrapper created during planning

    Returns: None
    """
    write_results = storage_writer.retrieve_write_results()

    # Gather the write results that will be saved to the metadata file.
    gather_start = time()
    all_results = dist_wrapper.gather_object(write_results)
    gather_end = time()
    logger.debug(f"{gather_end}, {torch.distributed.get_rank()}, gather: {gather_end - gather_start}")

    # Store the metadata on coordinator rank
    if dist_wrapper.is_coordinator:
        node_failures = _get_failure_dict(all_results)
        if len(node_failures) == 0:
            assert global_metadata is not None
            write_start = time()
            storage_writer.finish(global_metadata, all_results)
            write_end = time()
            logger.debug(f"{write_end}, metadata_write: {write_end - write_start}")
        else:
            raise CheckpointException("write", node_failures)
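
A hedged sketch of the plan -> async write -> finalize flow these two functions are designed for. The FileSystemWriterAsync API used here (thread_count, get_save_function_and_args) is the one this commit's strategies/torch.py relies on, not an independent torch.distributed API, and a real caller would run stage 2 in a background worker instead of inline:

from pathlib import Path

from megatron.core.dist_checkpointing.strategies.filesystem_async import FileSystemWriterAsync
from megatron.core.dist_checkpointing.strategies.state_dict_saver import (
    save_state_dict_async_finalize,
    save_state_dict_async_plan,
)


def three_stage_save(pyt_state_dict, checkpoint_dir: Path):
    writer = FileSystemWriterAsync(checkpoint_dir, thread_count=2)

    # Stage 1: planning (synchronous, collective across ranks)
    plan_result = save_state_dict_async_plan(pyt_state_dict, writer)

    # Stage 2: the actual writing; this is the part meant to happen asynchronously
    save_fn, save_args = writer.get_save_function_and_args()
    save_fn(*save_args)

    # Stage 3: finalization (gather write results, write metadata on the coordinator)
    save_state_dict_async_finalize(*plan_result)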
megatron/core/dist_checkpointing/strategies/tensorstore.py
0 → 100644
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.

""" Strategies using TensorStore to load and save Zarr arrays. """
from functools import partial
from itertools import starmap
from logging import getLogger
from pathlib import Path

import tensorstore as ts
import torch

from ..core import CheckpointingException
from ..dict_utils import dict_list_map_inplace
from ..mapping import ShardedStateDict, ShardedTensor
from .base import LoadShardedStrategy, StrategyAction, default_strategies
from .zarr import (
    load_zarr_based_sharded_metadata,
    numpy_to_torch_dtype_dict,
    postprocess_numpy_array,
)

_import_trigger = None

logger = getLogger(__name__)


class TensorStoreLoadShardedStrategy(LoadShardedStrategy):
    def __init__(self, load_directly_on_device: bool = False):
        super().__init__()
        self.load_directly_on_device = load_directly_on_device

    def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
        if torch.distributed.get_rank() == 0:
            print(f'Loading distributed checkpoint with {self.__class__.__name__}')
            if self.load_directly_on_device:
                print(f'Loading distributed checkpoint directly on the GPU')
        load_fn = partial(
            _load_from_array,
            checkpoint_dir=checkpoint_dir,
            load_directly_on_device=self.load_directly_on_device,
        )
        dict_list_map_inplace(load_fn, sharded_state_dict)
        return sharded_state_dict

    def load_tensors_metadata(self, checkpoint_dir: Path):
        def get_ts_shape_dtype(path):
            arr = open_ts_array(path)
            return arr.shape, arr.dtype.numpy_dtype

        return load_zarr_based_sharded_metadata(checkpoint_dir, get_ts_shape_dtype)

    def check_backend_compatibility(self, loaded_version):
        pass  # TODO

    def check_version_compatibility(self, loaded_version):
        pass  # TODO


def merge_global_slice_with_shape(global_slice, actual_shape, key):
    def _merge_slice(dim_slice, dim_size):
        if isinstance(dim_slice, slice):
            assert (
                dim_slice.start < dim_size
            ), f'Got empty slice for ShardedTensor {key} ({dim_slice}, {dim_size})'
            if dim_slice.stop > dim_size:
                dim_slice = slice(dim_slice.start, dim_size, dim_slice.step)
        return dim_slice

    assert len(global_slice) == len(actual_shape), (global_slice, actual_shape, key)
    return tuple(starmap(_merge_slice, zip(global_slice, actual_shape)))


def _load_from_array(
    sharded_tensor: ShardedTensor,
    checkpoint_dir: Path,
    load_directly_on_device: bool = False,
    apply_flattened_range: bool = True,
):
    x = _load_regular_chunk(sharded_tensor, checkpoint_dir)
    ten = postprocess_numpy_array(x, sharded_tensor, apply_flattened_range)
    if load_directly_on_device:
        sharded_tensor.data.data.copy_(ten)
        return sharded_tensor.data
    else:
        return ten


def _load_regular_chunk(sharded_tensor: ShardedTensor, checkpoint_dir: Path):
    assert isinstance(sharded_tensor, ShardedTensor), type(sharded_tensor)
    arr = open_ts_array(checkpoint_dir / sharded_tensor.key)
    if sharded_tensor.global_shape == arr.shape:
        x = (
            arr[sharded_tensor.global_slice()].read().result()
        )  # flattened tensors loading is delayed
    elif sharded_tensor.allow_shape_mismatch:
        global_slice = merge_global_slice_with_shape(
            sharded_tensor.global_slice(), arr.shape, sharded_tensor.key
        )
        x = arr[global_slice].read().result()  # flattened tensors loading is delayed
    else:
        _msg = (
            f'Global shape mismatch for loaded ({arr.shape})'
            f' and expected ({sharded_tensor.global_shape}) tensor'
            f' for key {sharded_tensor.key}'
        )
        raise CheckpointingException(_msg)
    return x


def open_ts_array(arr_path: Path):
    """Opens a Zarr file array with Tensorstore with basic setting.

    Args:
        arr_path (Path): path to a Zarr (Tensorstore) array
    """
    spec = {'driver': 'zarr', 'metadata_key': '.zarray', 'kvstore': {}}
    spec['kvstore'] = {
        'driver': 'file',
        'path': str(arr_path),
    }
    try:
        arr = ts.open(ts.Spec(spec), open=True).result()
    except Exception as e:
        raise CheckpointingException(f'Array {arr_path} could not be loaded. Error: {e}') from e
    return arr


default_strategies[StrategyAction.LOAD_SHARDED.value][
    ('zarr', 1)
] = TensorStoreLoadShardedStrategy()
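
Illustrative use of the TensorStore-based loader registered above. The checkpoint path is a placeholder and the sharded state dict is assumed to come from the model's sharded_state_dict() machinery defined elsewhere in this repository:

from pathlib import Path

from megatron.core.dist_checkpointing.strategies.tensorstore import TensorStoreLoadShardedStrategy

strategy = TensorStoreLoadShardedStrategy(load_directly_on_device=True)

# Peek at stored tensor shapes/dtypes without reading the data:
sharded_metadata = strategy.load_tensors_metadata(Path('/path/to/checkpoint'))

# A full load fills the `data` buffers of the ShardedTensors in `sharded_state_dict`:
# loaded = strategy.load(sharded_state_dict, Path('/path/to/checkpoint'))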
megatron/core/dist_checkpointing/strategies/torch.py
0 → 100644
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.

""" Strategies using PyTorch distributed.checkpoint as an underlying format. """
import dataclasses
import io
import itertools
import math
from collections import ChainMap, defaultdict
from dataclasses import dataclass
from itertools import product
from logging import getLogger
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, cast

import numpy as np
import torch
from torch.distributed import checkpoint
from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharded_tensor import Shard, ShardedTensorMetadata, TensorProperties
from torch.distributed._sharded_tensor import ShardedTensor as TorchShardedTensor
from torch.distributed.checkpoint import (
    DefaultLoadPlanner,
    DefaultSavePlanner,
    FileSystemReader,
    LoadPlan,
    Metadata,
    SavePlan,
    TensorStorageMetadata,
    WriteItem,
)
from torch.distributed.checkpoint._nested_dict import FLATTEN_MAPPING, unflatten_state_dict
from torch.distributed.checkpoint._traverse import OBJ_PATH, traverse_state_dict
from torch.distributed.checkpoint.default_planner import create_default_local_save_plan
from torch.distributed.checkpoint.planner_helpers import _create_write_items

from ..core import CheckpointingException
from ..dict_utils import nested_values
from ..mapping import (
    ShardedBase,
    ShardedObject,
    ShardedStateDict,
    ShardedTensor,
    StateDict,
    is_main_replica,
)
from .async_utils import AsyncRequest
from .base import (
    AsyncSaveShardedStrategy,
    LoadShardedStrategy,
    SaveShardedStrategy,
    StrategyAction,
    default_strategies,
)
from .filesystem_async import FileSystemWriterAsync
from .state_dict_saver import save_state_dict_async_finalize, save_state_dict_async_plan

_import_trigger = None

logger = getLogger(__name__)


def flatten_state_dict(
    state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, Dict[str, OBJ_PATH]]:
    """ Flattens state dict into a single level dict.

    It's a copy of torch.distributed.checkpoint._nested_dict.flatten_state_dict
    which also accepts ShardedBase tensors as terminal objects

    Args:
        state_dict (ShardedStateDict): state dict to be flattened

    Returns (tuple): flattened state dict and a mapping allowing to recreate the original one
    """
    flattened = {}
    mappings = {}

    def flat_copy(path: OBJ_PATH, value: Any) -> None:
        new_fqn = ".".join(map(str, path))
        if new_fqn in flattened:
            raise ValueError(f"duplicated flatten key {new_fqn}")
        flattened[new_fqn] = value
        mappings[new_fqn] = path

    traverse_state_dict(state_dict, flat_copy, lambda x: isinstance(x, (torch.Tensor, ShardedBase)))
    return flattened, mappings


def sharded_tensor_to_torch_sharded_tensor(
    sh_tens: List[ShardedTensor], rank: Optional[int] = None
) -> TorchShardedTensor:
    """Convert MCore ShardedTensor to PyT ShardedTensor. PyT requires information about all chunks.

    On high-level, this function follows the logic of torch.distributed.fsdp._shard_utils._create_chunk_sharded_tensor.
    Additionally, it saves `prepend_axis_num` and `has_flattened_range` (specific to MCore) as attributes
    for further restoration in `_unwrap_pyt_sharded_tensor`.

    NOTE: this function assumes regular (grid) sharding of the MCore ShardedTensor.
    The only local irregularities could be introduced with a `flattened_range` attribute.

    This function handles 3 different types of ShardedTensors:
    1. Non-flat regular ShardedTensors (`not has_flattened_range`)
    2. 1D flattened ShardedTensors (`is_flattened_range_1d`)
    3. N-D flattened ShardedTensors (`has_flattened_range`)

    Types (1) and (2) are saved according to their original shape.
    Type (3) however requires global shape adjustment for efficiency:
    we treat [X, Y, Z] global shape tensor with local shape [x, y, z]
    as a [X // x, Y // y, Z // z, x * y * z] tensor with last axis
    partitioned according to `flattened_range` slices.
    This will need special handling while resharding.

    Args:
        sh_tens (List[ShardedTensor]): list of sharded tensors to convert
        rank (int, optional): current process rank passed to PyT ShardedTensor.
            If None, assumes rank in the default pg.

    Returns (TorchShardedTensor): PyT ShardedTensor containing all passed shards.
    """
    if rank is None:
        rank = torch.distributed.get_rank()

    some_sh_ten = sh_tens[0]
    has_flattened_range = some_sh_ten.flattened_range is not None
    is_flattened_range_1d = has_flattened_range and len(some_sh_ten.global_shape) == 1

    for sh_ten in sh_tens:
        assert (sh_ten.flattened_range is not None) == has_flattened_range, sh_tens
        if not sh_ten.data.is_contiguous():
            sh_ten.data = sh_ten.data.contiguous()

    local_global_offsets = {}

    prepend_axis_num = sh_tens[0].prepend_axis_num
    # Determine local shards according to tensor type (see docs)
    if is_flattened_range_1d:
        # Type (2) case: 1D flattened ShardedTensors
        for sh_ten in sh_tens:
            assert len(sh_ten.global_offset) == 1, sh_ten
            assert sh_ten.prepend_axis_num == 0, sh_ten
            local_global_offsets.setdefault(sh_ten.global_offset, []).append(sh_ten)

        global_shape = some_sh_ten.global_shape
        offsets_shape = (
            some_sh_ten.local_shape
        )  # local shape is not flattened, we need it for chunk offsets

        local_shards = [
            Shard.from_tensor_and_offsets(
                sh_ten.data,
                [sh_ten.global_offset[0] + sh_ten.flattened_range.start],  # additional flattened offset
                rank,
            )
            for sh_ten in sh_tens
        ]

    elif has_flattened_range:
        # Type (3) case: N-D flattened ShardedTensors
        for sh_ten in sh_tens:
            local_global_offsets.setdefault(sh_ten.local_chunk_offset_in_global(), []).append(
                sh_ten
            )
            assert sh_ten.data.ndim == 1, sh_ten
            sh_ten.data = sh_ten.data.view((1,) * len(sh_ten.global_shape) + (-1,))

        # Global shape reformulation:
        global_shape = some_sh_ten.axis_fragmentations + (int(np.prod(some_sh_ten.local_shape)),)
        offsets_shape = (1,) * len(
            some_sh_ten.global_shape
        )  # reformulated global shape has shape equal to the number of local chunks

        local_shards = [
            Shard.from_tensor_and_offsets(
                sh_ten.data,
                list(
                    sh_ten.local_chunk_offset_in_global() + (sh_ten.flattened_range.start,)
                ),  # additional flattened offset
                rank,
            )
            for sh_ten in sh_tens
        ]
    else:
        # Type (1) case: non-flat regular ShardedTensors
        for sh_ten in sh_tens:
            local_global_offsets.setdefault(sh_ten.global_offset, []).append(sh_ten)
            sh_ten.data = sh_ten.data.view(
                (1,) * prepend_axis_num + sh_ten.local_shape
            )  # adjust to prepended_axis_num

        global_shape = some_sh_ten.global_shape
        offsets_shape = some_sh_ten.data.shape  # includes prepended axes

        local_shards = [
            Shard.from_tensor_and_offsets(
                sh_ten.data, list(sh_ten.global_offset), rank  # simple case
            )
            for sh_ten in sh_tens
        ]

    # Create a ShardedTensor without invoking communication. Determine global shards
    shard_metadata = []
    # NOTE: here we assume a regular grid of shards
    for fragment_offsets in itertools.product(*map(range, some_sh_ten.axis_fragmentations)):
        offset = tuple(map(lambda x: x[0] * x[1], zip(fragment_offsets, offsets_shape)))
        if offset in local_global_offsets:
            # local shard
            placement = f"rank:{rank}/cuda"
            for sh_ten in local_global_offsets[offset]:
                if is_flattened_range_1d:
                    offset = (sh_ten.global_offset[0] + sh_ten.flattened_range.start,)
                    size = sh_ten.data.shape
                elif has_flattened_range:
                    assert offset == sh_ten.local_chunk_offset_in_global()
                    # This is not an actual offset, but an offset of the whole shard
                    # This is needed for a PyT Dist internal integrity check
                    offset = sh_ten.local_chunk_offset_in_global() + (0,)
                    size = (1,) * len(offsets_shape) + global_shape[-1:]
                else:
                    size = sh_ten.data.shape
                shard_metadata.append(ShardMetadata(offset, size, placement))

        else:
            # for shards from other ranks we provide simplistic data - this information will be discarded
            # during TorchShardedTensor._init_from_local_shards_and_global_metadata call
            if has_flattened_range and not is_flattened_range_1d:
                offset = offset + (0,)
                size = (1,) * len(offsets_shape) + global_shape[-1:]
            else:
                size = offsets_shape
            shard_metadata.append(ShardMetadata(offset, size, "cuda"))

    tensor = some_sh_ten.data
    sharded_tensor_metadata = ShardedTensorMetadata(
        shards_metadata=shard_metadata,
        size=torch.Size(global_shape),
        tensor_properties=TensorProperties(
            dtype=tensor.dtype,
            layout=tensor.layout,
            requires_grad=tensor.requires_grad,
            memory_format=torch.contiguous_format,
            pin_memory=tensor.is_pinned(),
        ),
    )
    pyt_sh_ten = TorchShardedTensor._init_from_local_shards_and_global_metadata(
        local_shards, sharded_tensor_metadata=sharded_tensor_metadata, process_group=None
    )
    # Store MCore related data as PyTShardedTensor attribute. This won't be stored in the checkpoint, only for runtime purposes
    pyt_sh_ten.mcore_sh_ten = sh_ten.without_data()
    pyt_sh_ten.mcore_metadata = {}
    if has_flattened_range and not is_flattened_range_1d:
        pyt_sh_ten.mcore_metadata['nd_reformulated_orig_global_shape'] = sh_ten.global_shape
    return pyt_sh_ten
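
# Numeric illustration of the type (3) reformulation described in the docstring above
# (the shapes are arbitrary examples, not values taken from this repository): a [X, Y, Z]
# tensor cut into [x, y, z] local chunks is presented to PyT Dist as a
# [X // x, Y // y, Z // z, x * y * z] tensor whose last axis is partitioned by `flattened_range`.
import numpy as np

example_global_shape = (8, 4, 6)   # X, Y, Z
example_local_shape = (4, 2, 6)    # x, y, z -- one chunk per cell of a (2, 2, 1) grid

example_axis_fragmentations = tuple(
    g // l for g, l in zip(example_global_shape, example_local_shape)
)
reformulated_global_shape = example_axis_fragmentations + (int(np.prod(example_local_shape)),)
print(reformulated_global_shape)  # (2, 2, 1, 48)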
def mcore_to_pyt_state_dict(
    state_dict: Dict[str, List[ShardedBase]],
    is_loading: bool = False,
    init_device: torch.device = torch.device("cpu"),
) -> Dict[str, Union[TorchShardedTensor, io.BytesIO]]:
    """Turn state dict with ShardedTensors and ShardedObjects to state dict compatible with PyT Dist format.

    Operates in-place and returns the original state dict.

    Args:
        state_dict (Dict[str, List[ShardedBase]]): flattened state dict, where values
            are lists of either ShardedTensor or ShardedObjects.
        is_loading (bool, optional): flag indicating if loading or saving. Defaults to False.
        init_device (torch.device, optional): device to initialize potentially missing tensors
            during loading. Defaults to 'cpu'.

    Returns (Dict[str, Union[TorchShardedTensor, io.BytesIO]]): original dictionary with values
        converted either into PyT ShardedTensors or io.BytesIO.
    """
    rank = torch.distributed.get_rank()
    pyt_state_dict = {}

    def _mcore_to_torch_sharded_tensor(sh_tens: List[ShardedTensor]) -> TorchShardedTensor:
        """Build a PyT ShardedTensor from given shards.

        During loading:
        - if data is None, initialize it with an empty tensor (will be used to copy the data into)
        - if `allow_shape_mismatch` is True, the data is initialized with zeros
            prior to loading (not all parts of the tensor will be read from the checkpoint)
        """
        assert all(isinstance(sh_ten, ShardedTensor) for sh_ten in sh_tens), sh_tens
        for sh_ten in sh_tens:
            if sh_ten.data is None:
                if is_loading:
                    sh_ten.init_data(
                        init_device,
                        init_fn=torch.zeros if sh_ten.allow_shape_mismatch else torch.empty,
                    )
                else:
                    raise CheckpointingException(f'`data` attr is None for {sh_ten}')
            else:
                sh_ten.data = sh_ten.data.detach()
                if sh_ten.allow_shape_mismatch and is_loading:
                    sh_ten.data.zero_()

        torch_sh_ten = sharded_tensor_to_torch_sharded_tensor(sh_tens, rank)
        torch_sh_ten.key = sh_tens[0].key
        return torch_sh_ten

    def _mcore_to_torch_sharded_object(sh_objs: List[ShardedObject]) -> io.BytesIO:
        """Build io.BytesIO from given sharded objects data."""
        assert all(isinstance(sh_obj, ShardedObject) for sh_obj in sh_objs), sh_objs
        serialized_data = io.BytesIO()
        torch.save([sh_obj.data for sh_obj in sh_objs], serialized_data)
        return serialized_data

    for k, v in state_dict.items():
        if isinstance(v[0], ShardedTensor):
            v = cast(List[ShardedTensor], v)
            pyt_state_dict[k] = _mcore_to_torch_sharded_tensor(v)
        else:
            v = cast(List[ShardedObject], v)
            pyt_state_dict[k] = _mcore_to_torch_sharded_object(v)

    return pyt_state_dict


def _unwrap_pyt_sharded_tensor(sh_ten: TorchShardedTensor) -> List[torch.Tensor]:
    """ Unwrap tensor from PyT ShardedTensor instance.

    If `prepend_axis_num` was non-zero (which is specific to MCore ShardedTensor)
    then the tensor has additional singleton dimensions which should be squeezed.
    """
    mcore_sh_ten = sh_ten.mcore_sh_ten
    ret_tensors = []
    for sh in sh_ten.local_shards():
        ten = sh.tensor
        if mcore_sh_ten.flattened_range is not None:
            assert ten.shape[:-1] == (1,) * (len(ten.shape) - 1), ten.shape
            ten = ten.view(-1)
        else:
            for _ in range(mcore_sh_ten.prepend_axis_num):
                ten = ten.squeeze(0)
        ret_tensors.append(ten)
    return ret_tensors


def _replace_state_dict_keys_with_sharded_keys(
    sharded_state_dict: ShardedStateDict, keep_only_main_replica: bool = False
) -> Tuple[Dict[str, List[ShardedBase]], FLATTEN_MAPPING, Dict[str, List[str]]]:
    """Group ShardedBase objects by keys and return mappings required for recreating the original dict. """
    flat_sd, flat_mapping = flatten_state_dict(sharded_state_dict)
    rename_mapping = defaultdict(list)
    new_flat_sd = defaultdict(list)
    for k, sh_base in flat_sd.items():
        assert isinstance(sh_base, ShardedBase), type(sh_base)
        key = sh_base.unique_key if isinstance(sh_base, ShardedObject) else sh_base.key
        if is_main_replica(sh_base.replica_id) or not keep_only_main_replica:
            rename_mapping[key].append(k)
            new_flat_sd[key].append(sh_base)
    return new_flat_sd, flat_mapping, rename_mapping


def _replace_sharded_keys_with_state_dict_keys(
    state_dict: Dict[str, List[Union[torch.Tensor, io.BytesIO]]],
    flat_mapping: FLATTEN_MAPPING,
    rename_mapping: Dict[str, List[str]],
):
    """ Inverse of _replace_state_dict_keys_with_sharded_keys. """
    recovered_sd = {}
    for k, tensors in state_dict.items():
        assert len(tensors) == len(rename_mapping[k])
        for ten, recovered_k in zip(tensors, rename_mapping[k]):
            recovered_sd[recovered_k] = ten

    return unflatten_state_dict(recovered_sd, flat_mapping)


def _restore_dict_types(x: Union[dict, list, Any], keys_template: Union[dict, list, Any]):
    """ Recursively update `x` keys, based on `keys_template`. """
    if isinstance(keys_template, dict):
        assert isinstance(x, dict), type(x)
        for k, v in keys_template.items():
            if not isinstance(k, str):
                assert str(k) in x, (k, x.keys)
                x[k] = x.pop(str(k))
            _restore_dict_types(x[k], v)
    elif isinstance(keys_template, list):
        assert isinstance(x, list), type(x)
        for x_val, templ_val in zip(x, keys_template):
            _restore_dict_types(x_val, templ_val)


@dataclass(frozen=True)
class MCoreSavePlan(SavePlan):
    mcore_data: Dict[str, Dict[str, Any]] = None  # Mcore related data about each tensor


class MCoreSavePlanner(DefaultSavePlanner):
    """Differs from the default planner by saving BytesIO objects on all ranks.

    In the integration of MCore with PyT Distributed format, BytesIO objects
    come from ShardedObjects, which should be treated as separate objects on each rank
    (not common on all ranks).

    Also, the objects are already packed in io.BytesIO, so no need to redo it
    in transform_object.
    """

    def __init__(
        self,
        *args,
        nd_flattened_global_shapes: Optional[Dict[str, Tuple[int, ...]]] = None,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.nd_flattened_global_shapes = nd_flattened_global_shapes or {}

    def create_local_plan(self) -> SavePlan:
        plan = create_default_local_save_plan(self.state_dict, self.is_coordinator)
        self._add_non_coordinator_iobytes_request(plan)
        if self.flatten_state_dict:
            plan = dataclasses.replace(plan, planner_data=self.mappings)
        plan = MCoreSavePlan(
            items=plan.items,
            storage_data=plan.storage_data,
            planner_data=plan.planner_data,
            mcore_data={
                k: sh_ten.mcore_metadata
                for k, sh_ten in self.state_dict.items()
                if isinstance(sh_ten, TorchShardedTensor)
            },
        )
        self.plan = plan

        return self.plan

    def create_global_plan(self, all_plans: List[MCoreSavePlan]) -> Tuple[List[SavePlan], Metadata]:
        global_plan, metadata = super().create_global_plan(all_plans)
        metadata.mcore_data = dict(ChainMap(*(plan.mcore_data for plan in all_plans)))
        return global_plan, metadata

    def _add_non_coordinator_iobytes_request(self, plan):
        if self.is_coordinator:
            return
        for fqn, obj in self.state_dict.items():
            if isinstance(obj, io.BytesIO):
                plan.items.extend(_create_write_items(fqn, obj))

    def transform_object(self, write_item: WriteItem, object: Any):
        return object


class MCoreLoadPlanner(DefaultLoadPlanner):
    """Adds global shape validation to the default planner.

    If global shape validation can be ignored (shouldn't!), the default
    load planner can be used.
    """

    def __init__(
        self, *args, shapes_validation_sharded_tensors: Iterable[ShardedTensor] = (), **kwargs
    ) -> None:
        super().__init__(*args, **kwargs)
        self.shapes_validation_sharded_tensors = shapes_validation_sharded_tensors

    def _validate_global_shapes(self, metadata, sharded_tensors):
        for sh_ten in sharded_tensors:
            loaded_shape = metadata.state_dict_metadata[sh_ten.key].size
            if sh_ten.flattened_range is None or len(sh_ten.global_shape) == 1:
                expected_shape = sh_ten.global_shape
            else:
                expected_shape = sh_ten.axis_fragmentations + (int(np.prod(sh_ten.local_shape)),)
            if loaded_shape != expected_shape:
                _msg = (
                    f'Global shape mismatch for loaded ({loaded_shape})'
                    f' and expected ({expected_shape}) tensor'
                    f' for key {sh_ten.key}'
                )
                raise CheckpointingException(_msg)

    def create_local_plan(self) -> LoadPlan:
        self._validate_global_shapes(self.metadata, self.shapes_validation_sharded_tensors)
        return super().create_local_plan()


class TorchDistSaveShardedStrategy(AsyncSaveShardedStrategy):
    """Async save strategy for the PyT Distributed format.

    The idea is to translate MCore ShardedTensors into PyT ShardedTensors
    and use the async-adjusted torch.distributed.checkpoint saving mechanism
    provided by the FileSystemWriterAsync writer.
    """

    def __init__(
        self, backend: str, version: int, keep_only_main_replica: bool = True, thread_count: int = 2
    ):
        """Adds parameters specific to PyT Distributed format
        Args:
            backend (str): format backend string
            version (int): format version
            keep_only_main_replica (bool, optional): PyT Distributed has a mechanism
                for deduplication, but replica_id aware deduplication is more coherent.
                Default is True (recommended to keep it).
            thread_count (int, optional): threads to use during saving.
                Affects the number of files in the checkpoint (saving ranks * num_threads).
        """
        super().__init__(backend, version)
        self.keep_only_main_replica = keep_only_main_replica
        self.thread_count = thread_count

    def async_save(
        self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path
    ) -> AsyncRequest:
        """ Translates MCore ShardedTensors to PyT ShardedTensors and saves in PyT Distributed format.

        Args:
            sharded_state_dict (ShardedStateDict): sharded state dict to save
            checkpoint_dir (Path): checkpoint directory

        Returns: None
        """
        # Translate the state dict
        (
            sharded_state_dict,
            flat_mapping,
            rename_mapping,
        ) = _replace_state_dict_keys_with_sharded_keys(
            sharded_state_dict, self.keep_only_main_replica
        )
        pyt_state_dict = mcore_to_pyt_state_dict(sharded_state_dict, False)
        # Use PyT saving mechanism
        writer = FileSystemWriterAsync(checkpoint_dir, thread_count=self.thread_count)

        save_state_dict_ret = save_state_dict_async_plan(
            pyt_state_dict,
            writer,
            None,
            planner=MCoreSavePlanner(dedup_replicated_tensors=not self.keep_only_main_replica),
        )
        return self._get_save_and_finalize_callbacks(writer, save_state_dict_ret)

    def _get_save_and_finalize_callbacks(self, writer, save_state_dict_ret) -> AsyncRequest:
        save_fn_args = writer.get_save_function_and_args()
        save_fn, save_args = save_fn_args

        def finalize_fn():
            save_state_dict_async_finalize(*save_state_dict_ret)
            torch.distributed.barrier()

        return AsyncRequest(save_fn, save_args, [finalize_fn])

    def can_handle_sharded_objects(self):
        return True


class TorchDistLoadShardedStrategy(LoadShardedStrategy):
    """Basic load strategy for the PyT Distributed format. """

    def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> StateDict:
        """Translates MCore ShardedTensors to PyT ShardedTensors and loads from PyT Distributed format.

        Args:
            sharded_state_dict (ShardedStateDict): sharded state dict with mapping
                information to instruct loading
            checkpoint_dir (Path): checkpoint directory

        Returns: loaded state dict
        """
        flexible_shape_sharded_tensors = [
            sh_ten
            for sh_ten in nested_values(sharded_state_dict)
            if isinstance(sh_ten, ShardedTensor) and not sh_ten.allow_shape_mismatch
        ]

        orig_sharded_state_dict = sharded_state_dict
        # MCore state dict to PyT Distributed compatible
        (
            sharded_state_dict,
            flat_mapping,
            rename_mapping,
        ) = _replace_state_dict_keys_with_sharded_keys(sharded_state_dict)
        pyt_state_dict = mcore_to_pyt_state_dict(sharded_state_dict, True)
        # Load PyT Distributed format
        checkpoint.load_state_dict(
            pyt_state_dict,
            FileSystemReader(checkpoint_dir),
            planner=MCoreLoadPlanner(
                shapes_validation_sharded_tensors=flexible_shape_sharded_tensors
            ),
        )
        pyt_state_dict = cast(
            Dict[str, Union[TorchShardedTensor, List[io.BytesIO]]], pyt_state_dict
        )
        # Unwrap ShardedTensors and return to original state dict
        mcore_state_dict = {
            k: v if not isinstance(v, TorchShardedTensor) else _unwrap_pyt_sharded_tensor(v)
            for k, v in pyt_state_dict.items()
        }
        mcore_state_dict = _replace_sharded_keys_with_state_dict_keys(
            mcore_state_dict, flat_mapping, rename_mapping
        )
        _restore_dict_types(mcore_state_dict, orig_sharded_state_dict)
        return mcore_state_dict

    def load_tensors_metadata(self, checkpoint_dir: Path):
        """Uses tensors metadata stored in the metadata file."""
        fs_reader = FileSystemReader(checkpoint_dir)
        metadata = fs_reader.read_metadata()

        mcore_data = getattr(metadata, 'mcore_data', {})
        sharded_metadata = {}
        for k, tp in metadata.state_dict_metadata.items():
            if not isinstance(tp, TensorStorageMetadata):
                continue  # load only tensors

            nd_orig_global_shape = mcore_data.get(k, {}).get('nd_reformulated_orig_global_shape')
            if nd_orig_global_shape is None:
                # Regular tensor
                sharded_metadata[k] = ShardedTensor.from_rank_offsets(
                    k, torch.empty(tp.size, **tp.properties.__dict__, device='meta'),
                ).without_data()
            else:
                # N-D flattened tensor
                unflat_ten = torch.empty(
                    nd_orig_global_shape, **tp.properties.__dict__, device='meta'
                )
                flat_ten = unflat_ten.flatten()
                sharded_metadata[k] = ShardedTensor.from_rank_offsets_flat(
                    k,
                    flat_ten,
                    unflat_ten.shape,
                    flattened_range=slice(0, unflat_ten.numel()),  # whole slice
                ).without_data()

        return sharded_metadata

    def can_handle_sharded_objects(self):
        return True

    def check_backend_compatibility(self, loaded_version):
        pass  # TODO

    def check_version_compatibility(self, loaded_version):
        pass  # TODO


default_strategies[StrategyAction.LOAD_SHARDED.value][
    ('torch_dist', 1)
] = TorchDistLoadShardedStrategy()
default_strategies[StrategyAction.SAVE_SHARDED.value][
    ('torch_dist', 1)
] = TorchDistSaveShardedStrategy('torch_dist', 1)
megatron/core/dist_checkpointing/strategies/two_stage.py
0 → 100644
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.

""" 2-stage checkpoint loading. """
import os
import time
from collections import defaultdict
from dataclasses import dataclass
from functools import partial, wraps
from itertools import chain
from logging import DEBUG, INFO, StreamHandler, getLogger
from operator import attrgetter, itemgetter
from pathlib import Path
from typing import Iterable, List, NamedTuple, Optional, Tuple, Union

import torch

from ..dict_utils import dict_list_map_inplace, map_reduce, nested_values
from ..mapping import ShardedStateDict, ShardedTensor, StateDict
from .base import LoadShardedStrategy
from .tensorstore import TensorStoreLoadShardedStrategy, _load_from_array, open_ts_array
from .zarr import flatten_range, load_zarr_based_sharded_metadata

_import_trigger = None


timers = defaultdict(list)

logger = getLogger(__name__)


def timed(verbose=True):
    def timed_dec(fn):
        name = fn.__name__

        @wraps(fn)
        def wrapped(*args, **kwargs):
            if verbose:
                logger.debug(f'{name} init')
            start = time.time()
            ret = fn(*args, **kwargs)
            took = time.time() - start
            if verbose:
                logger.debug(f'{name} took {took}s')
            timers[name].append(took)
            return ret

        return wrapped

    return timed_dec


@dataclass
class _ShardedTensorMetadata:
    global_rank: int
    sharded_tensor_no_data: ShardedTensor
    dist_group_rank: Tuple[int]  # id of distributed group
    dist_group_ranks: Tuple[int]  # id of distributed group
    data_size: Optional[int] = None  # bytes


def sharded_tensor_chunk_id(sharded_tensor: ShardedTensor):
    return (
        sharded_tensor.key,
        sharded_tensor.global_offset,
    )


class TwoStageDataParallelLoadShardedStrategy(LoadShardedStrategy):
    """Loads one checkpoint replica from storage and broadcasts to other nodes.

    This strategy loads checkpoint from storage on minimal set of nodes
    and distributes the checkpoint to other nodes with torch.distributed.
    Loading is performed with tensorstore.

    Steps:
    0. (optional) create Gloo distributed groups
    1. Exchange ShardedTensors metadata between all nodes
    2. Align needed tensors within DP groups
    3. For each globally unique tensor:
      3.a) on one of the ranks load it from storage to CPU and move to CUDA
      3.b) allocate CUDA tensor on other ranks
      3.c) broadcast within DP group
      3.d) copy tensor content to the model param location
      3.e) free tensor buffers from a) and b)

    Notes:
    1. Loading and broadcasting is done sequentially to avoid both host and device OOMs
    2. There is a lot of overlap potential between all three steps done for each tensor:
      2.a) loading from storage to numpy
      2.b) moving CPU tensors to CUDA
      2.c) broadcast
    """

    def __init__(self, data_parallel_group, cpu_transfer=True):
        super().__init__()

        self.cpu_transfer = cpu_transfer
        self.data_parallel_group_orig = data_parallel_group
        self.data_parallel_group = None if cpu_transfer else data_parallel_group
        self.dp_group_ranks = tuple(
            sorted(torch.distributed.get_process_group_ranks(data_parallel_group))
        )
        self.dp_group_rank = torch.distributed.get_rank(self.data_parallel_group_orig)
        self.global_rank = torch.distributed.get_rank()

    def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
        self.maybe_init_gloo_group()
        all_tensors_sorted = self._build_load_plan(sharded_state_dict)
        self._exchange_loaded_tensors(all_tensors_sorted, sharded_state_dict, checkpoint_dir)
        # TODO: fix hang in summarize_load_times
        # self.summarize_load_times()
        return sharded_state_dict

    def summarize_load_times(self):
        torch.distributed.barrier()
        logger.info('Checkpoint loading finished. Summary:')
        # TODO: `timers` keys are not guaranteed to be the same across ranks which causes hangs
        for key, times in sorted(timers.items()):
            times_sum = sum(times)
            max_times = torch.tensor([times_sum], device='cuda')
            avg_times = torch.tensor([times_sum], device='cuda')
            torch.distributed.all_reduce(max_times, op=torch.distributed.ReduceOp.MAX)
            torch.distributed.all_reduce(avg_times, op=torch.distributed.ReduceOp.SUM)
            avg_times /= torch.distributed.get_world_size()
            if torch.distributed.get_rank() == 0:
                logger.info(f'{key}: max {max_times[0]}, avg {avg_times[0]}')

    @timed(verbose=False)
    def load_tensor_from_storage(self, checkpoint_dir, ten_meta: _ShardedTensorMetadata):
        logger.debug(f'_load_from_array({ten_meta.sharded_tensor_no_data.key}) init')
        ret = _load_from_array(
            ten_meta.sharded_tensor_no_data,
            checkpoint_dir,
            load_directly_on_device=False,
            apply_flattened_range=False,
        )
        logger.debug(f'_load_from_array({ten_meta.sharded_tensor_no_data.key}) DONE')
        return ret

    @timed()
    def maybe_init_gloo_group(self):
        if not self.cpu_transfer:
            return
        all_groups = [None] * torch.distributed.get_world_size()
        torch.distributed.all_gather_object(all_groups, self.dp_group_ranks)
        all_groups = set(tuple(sorted(gr)) for gr in all_groups)
        for group_ranks in sorted(all_groups):
            gloo_pg = torch.distributed.new_group(ranks=group_ranks, backend='gloo')
            if self.global_rank in group_ranks:
                self.data_parallel_group = gloo_pg
                assert self.dp_group_rank == torch.distributed.get_rank(self.data_parallel_group)

    def check_backend_compatibility(self, loaded_version):
        pass  # TODO

    def check_version_compatibility(self, loaded_version):
        pass  # TODO

    @timed()
    def _build_load_plan(
        self, sharded_state_dict: ShardedStateDict
    ) -> List[_ShardedTensorMetadata]:
        local_meta = [
            _ShardedTensorMetadata(
                self.global_rank,
                sharded_ten.without_data(),
                self.dp_group_rank,
                self.dp_group_ranks,
            )
            for sharded_ten in nested_values(sharded_state_dict)
        ]
        all_meta = [None] * torch.distributed.get_world_size(group=self.data_parallel_group)
        torch.distributed.all_gather_object(all_meta, local_meta, group=self.data_parallel_group)
        all_meta = list(chain.from_iterable(all_meta))
        all_tensors_sorted = self.deduplicate_chunks(all_meta)
        return all_tensors_sorted

    @timed()
    def deduplicate_chunks(self, ten_metas: List[_ShardedTensorMetadata]):
        """ Group tensors by chunk and then pick the tensor with the lowest rank.

        NOTE: with proper loading overlap, loading from randomized ranks
        (instead of the smallest one) could be beneficial here.
        """
        ten_metas = map_reduce(
            ten_metas,
            key_fn=lambda meta: sharded_tensor_chunk_id(meta.sharded_tensor_no_data),
            reduce_fn=partial(min, key=attrgetter('dist_group_rank')),
        )
        all_metas_sorted = list(map(itemgetter(1), sorted(ten_metas.items())))
        return all_metas_sorted

    @timed()
    def _exchange_loaded_tensors(
        self, ten_metas: List[_ShardedTensorMetadata], sharded_state_dict, checkpoint_dir
    ):
        logger.debug(f'_exchange_loaded_tensors, num ten_metas: {len(ten_metas)}')
        for ten_meta in ten_metas:

            src_rank = torch.distributed.get_global_rank(
                self.data_parallel_group, ten_meta.dist_group_rank
            )

            if self.dp_group_rank == ten_meta.dist_group_rank:
                exchange_tensor = self.load_tensor_from_storage(checkpoint_dir, ten_meta)
                if not self.cpu_transfer:
                    exchange_tensor = exchange_tensor.cuda()
            else:
                # TODO: for non-flattened ranges we could reuse the buffer from the start here
                exchange_tensor = torch.empty(
                    ten_meta.sharded_tensor_no_data.local_shape,
                    device='cpu' if self.cpu_transfer else 'cuda',
                    dtype=ten_meta.sharded_tensor_no_data.dtype,
                )

            logger.debug(
                f'exchange {ten_meta.sharded_tensor_no_data.key}, {exchange_tensor.shape}({exchange_tensor.numel()}), broadcast({src_rank} -> {self.dp_group_ranks})'
            )
            torch.distributed.broadcast(
                exchange_tensor, group=self.data_parallel_group, src=src_rank
            )
            self._distribute_data_to_state_dict(ten_meta, exchange_tensor, sharded_state_dict)
            logger.debug(f'exchange {ten_meta.sharded_tensor_no_data.key} done')

            # free buffer memory
            exchange_tensor = None

    @timed(verbose=False)
    def _distribute_data_to_state_dict(
        self,
        ten_meta: _ShardedTensorMetadata,
        loaded_ten: torch.Tensor,
        sharded_state_dict: ShardedStateDict,
    ):
        tensor_key = sharded_tensor_chunk_id(ten_meta.sharded_tensor_no_data)

        def _fill_in_data(t: Union[ShardedTensor, torch.Tensor]):
            if not isinstance(t, ShardedTensor) or sharded_tensor_chunk_id(t) != tensor_key:
                # already filled-in or key not matching
                return t
            sharded_tensor: ShardedTensor = t
            x = loaded_ten
            if sharded_tensor.flattened_range is not None:
                x = flatten_range(sharded_tensor, x)

            # Reuse existing buffer
            sharded_tensor.data.data.copy_(x)
            return sharded_tensor.data

        dict_list_map_inplace(_fill_in_data, sharded_state_dict)

    def load_tensors_metadata(self, checkpoint_dir: Path):
        def get_ts_shape_dtype(path):
            arr = open_ts_array(path)
            return arr.shape, arr.dtype.numpy_dtype

        return load_zarr_based_sharded_metadata(checkpoint_dir, get_ts_shape_dtype)
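
Illustrative construction of the two-stage loader defined above. The data-parallel group is whatever process group the caller already maintains; here it is assumed to come from megatron.core.parallel_state, which is part of the same codebase:

from megatron.core import parallel_state
from megatron.core.dist_checkpointing.strategies.two_stage import (
    TwoStageDataParallelLoadShardedStrategy,
)

strategy = TwoStageDataParallelLoadShardedStrategy(
    data_parallel_group=parallel_state.get_data_parallel_group(),
    cpu_transfer=True,  # broadcast over a Gloo group on CPU instead of NCCL on GPU
)
# loaded = strategy.load(sharded_state_dict, checkpoint_dir)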
megatron/core/dist_checkpointing/strategies/zarr.py
0 → 100644
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.

""" Strategies using Zarr as an underlying format. """
import logging
import os
import threading
from functools import partial
from logging import getLogger
from pathlib import Path
from typing import Callable, List, Optional, Tuple

import numpy as np
import torch
import zarr

from ..core import CheckpointingException
from ..dict_utils import dict_list_map_inplace, nested_values
from ..mapping import ShardedStateDict, ShardedTensor, is_main_replica
from .base import LoadShardedStrategy, SaveShardedStrategy, StrategyAction, default_strategies

logger = logging.getLogger(__name__)

numpy_to_torch_dtype_dict = {
    np.dtype('bool'): torch.bool,
    np.dtype('uint8'): torch.uint8,
    np.dtype('int8'): torch.int8,
    np.dtype('int16'): torch.int16,
    np.dtype('int32'): torch.int32,
    np.dtype('int64'): torch.int64,
    np.dtype('float16'): torch.float16,
    np.dtype('float32'): torch.float32,
    np.dtype('float64'): torch.float64,
    np.dtype('complex64'): torch.complex64,
    np.dtype('complex128'): torch.complex128,
}

torch_to_numpy_dtype_dict = {v: k for k, v in numpy_to_torch_dtype_dict.items()}


try:
    import tensorstore

    HAS_BFLOAT16 = True
    numpy_to_torch_dtype_dict[np.dtype('bfloat16')] = torch.bfloat16
    torch_to_numpy_dtype_dict[torch.bfloat16] = np.dtype('bfloat16')
except ImportError:
    HAS_BFLOAT16 = False

_import_trigger = None

logger = getLogger(__name__)


class ZarrSaveShardedStrategy(SaveShardedStrategy):
    def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
        sharded_tensors = list(nested_values(sharded_state_dict))
        arrays = _create_or_open_zarr_arrays(sharded_tensors, checkpoint_dir)
        for ten, arr in zip(sharded_tensors, arrays):
            _save_to_existing_array(ten, arr)
        torch.distributed.barrier()


def _create_or_open_zarr_arrays(
    sharded_tensors: List[ShardedTensor], checkpoint_dir: Path
) -> List[Optional[zarr.Array]]:
    """ Returns list of zarr arrays corresponding to given tensors.

    For a sharded tensor that:
    a) is main replica and represents the first chunk (all offsets 0), creates the Zarr array
    b) is main replica but not the first chunk, opens the arrays created in (a) (possibly by other process)
    c) otherwise, sets the corresponding array to None since it won't be used

    Args:
        sharded_tensors (List[ShardedTensor]): sharded tensors from a given rank that will be saved to checkpoint
        checkpoint_dir (Path): checkpoint in which the arrays will be created
    """
    arrays = []
    for ten in sharded_tensors:
        arr = _create_zarr_array(ten, checkpoint_dir) if _should_create_array(ten) else None
        arrays.append(arr)

    torch.distributed.barrier()
    # Open arrays created above by other processes
    for arr_idx, ten in enumerate(sharded_tensors):
        if arrays[arr_idx] is not None:
            # array created by this process
            assert _should_create_array(ten), ten
            continue
        if not is_main_replica(ten.replica_id):
            # this array won't be needed for saving and can stay None
            continue
        open_kwargs = {}
        if ten.flattened_range is not None:
            open_kwargs['synchronizer'] = zarr.ProcessSynchronizer(
                str(checkpoint_dir / f'{ten.key}.sync')
            )
        arrays[arr_idx] = _open_zarr_array_verbose(checkpoint_dir / ten.key, 'r+', **open_kwargs)
    return arrays


def _should_create_array(ten: ShardedTensor):
    return (
        is_main_replica(ten.replica_id)
        and set(ten.global_offset) == {0}
        and (ten.flattened_range is None or ten.flattened_range.start == 0)
    )


def _save_to_existing_array(sharded_tensor: ShardedTensor, arr: Optional[zarr.Array]):
    if not is_main_replica(sharded_tensor.replica_id):
        return
    assert arr is not None
    x = sharded_tensor.data
    x = x.detach().cpu()
    torch.cuda.synchronize()
    if x.dtype == torch.bfloat16:
        x = x.float()
        x = x.numpy()
        x = x.astype('bfloat16')
    else:
        x = x.numpy()

    if sharded_tensor.flattened_range is None:
        arr[sharded_tensor.global_slice()] = x
    else:
        arr.set_coordinate_selection(sharded_tensor.global_coordinates(), x)


def _create_zarr_array(sharded_tensor: ShardedTensor, checkpoint_dir: Path):
    np_dtype = torch_to_numpy_dtype_dict[sharded_tensor.dtype]
    try:
        arr = zarr.create(
            sharded_tensor.global_shape,
            dtype=np_dtype,
            store=checkpoint_dir / sharded_tensor.key,
            chunks=sharded_tensor.max_allowed_chunks(),
            compressor=None,
            fill_value=None,
            write_empty_chunks=True,
        )
        logger.debug(f'Created a new Zarr array at {checkpoint_dir / sharded_tensor.key}')
    except zarr.errors.ContainsArrayError as e:
        raise CheckpointingException(
            f'Array {checkpoint_dir / sharded_tensor.key} already exists'
        ) from e

    if HAS_BFLOAT16 and np_dtype == np.dtype('bfloat16'):
        arr._dtype = np_dtype
        zarray = arr.store['.zarray']
        arr.store['.zarray'] = zarray.replace(b'<V2', b'bfloat16')
    return arr


class ZarrLoadShardedStrategy(LoadShardedStrategy):
    def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
        dict_list_map_inplace(
            partial(_load_from_array, checkpoint_dir=checkpoint_dir), sharded_state_dict
        )
        return sharded_state_dict

    def load_tensors_metadata(self, checkpoint_dir: Path):
        def get_zarr_shape_dtype(path):
            arr = zarr.open(path, 'r')
            return arr.shape, arr.dtype

        return load_zarr_based_sharded_metadata(checkpoint_dir, get_zarr_shape_dtype)

    def check_backend_compatibility(self, loaded_version):
        pass  # TODO

    def check_version_compatibility(self, loaded_version):
        pass  # TODO


def _load_from_array(sharded_tensor: ShardedTensor, checkpoint_dir: Path):
    assert isinstance(sharded_tensor, ShardedTensor), type(sharded_tensor)
    arr = _open_zarr_array_verbose(checkpoint_dir / sharded_tensor.key, 'r')

    if not sharded_tensor.allow_shape_mismatch and sharded_tensor.global_shape != arr.shape:
        _msg = (
            f'Global shape mismatch for loaded ({arr.shape})'
            f' and expected ({sharded_tensor.global_shape}) tensor'
            f' for key {sharded_tensor.key}'
        )
        raise CheckpointingException(_msg)

    x = arr[sharded_tensor.global_slice()]  # flattened tensors loading is delayed
    return postprocess_numpy_array(x, sharded_tensor)


def _open_zarr_array_verbose(path: Path, mode: str, **open_kwargs):
    try:
        return zarr.open(str(path), mode, **open_kwargs)
    except zarr.errors.PathNotFoundError as e:
        ckpt_dir = path.parent
        err_msg = f'Array {path} not found'
        if ckpt_dir.exists():
            ckpt_files = [f.name for f in ckpt_dir.iterdir()]
            logger.debug(f'{err_msg}. Checkpoint directory {ckpt_dir} content: {ckpt_files}')
        else:
            err_msg += f'. Checkpoint directory {ckpt_dir} does not exist.'
        raise CheckpointingException(err_msg) from e


def postprocess_numpy_array(loaded_array, sharded_tensor, apply_flattened_range=True):
    x = loaded_array
    if HAS_BFLOAT16 and x.dtype == np.dtype('bfloat16'):
        x = x.astype(np.dtype('float32'))
        x = torch.from_numpy(x)
        x = x.bfloat16()
    else:
        x = torch.from_numpy(x)

    # TODO: consider some other consistency checks
    if x.shape != sharded_tensor.local_shape:
        if sharded_tensor.allow_shape_mismatch:
            x = pad_to_expected_shape(x, sharded_tensor)
        else:
            _msg = (
                f'Local shape mismatch for loaded ({x.shape})'
                f' and expected ({sharded_tensor.local_shape}) tensor'
                f' for key {sharded_tensor.key}'
            )
            raise CheckpointingException(_msg)

    if apply_flattened_range and sharded_tensor.flattened_range is not None:
        x = flatten_range(sharded_tensor, x)

    # TODO: consider cuda() tensors support
    return x


def flatten_range(sharded_tensor, x):
    return x.flatten()[sharded_tensor.flattened_range]


def pad_to_expected_shape(x: torch.Tensor, expected_sharded_ten: ShardedTensor):
    pad_args = []
    assert len(x.shape) == len(expected_sharded_ten.local_shape)
    # Reversed iteration order because F.pad expects so
    for x_sh, exp_sh, axis_fragm in reversed(
        list(
            zip(x.shape, expected_sharded_ten.local_shape, expected_sharded_ten.axis_fragmentations)
        )
    ):
        if x_sh == exp_sh:
            pad_args.extend((0, 0))
        elif x_sh > exp_sh:
            assert (
                False
            ), f'Expected shape ({exp_sh}) smaller than actual ({x_sh}) for {repr(expected_sharded_ten)}'
        else:
            pad_args.extend((0, exp_sh - x_sh))
    # TODO: behavior control with envvar is for testing purposes only, remove it
    if not int(os.environ.get('DIST_CKPT_PAD_REPLICATE', 0)):
        return torch.nn.functional.pad(x, pad_args)

    # unsqueeze and squeeze to get shapes supported by cudnn
    print(f'Replicating last row for {expected_sharded_ten.key}')
    if x.dtype == torch.bfloat16:
        return (
            torch.nn.functional.pad(x.float().unsqueeze(0), pad_args, mode='replicate')
            .squeeze(0)
            .bfloat16()
        )
    return torch.nn.functional.pad(x.unsqueeze(0), pad_args, mode='replicate').squeeze(0)


def load_zarr_based_sharded_metadata(
    checkpoint_dir: Path, get_shape_dtype_fn: Callable[[str], Tuple[Tuple[int], np.dtype]]
) -> ShardedStateDict:
    """Load metadata of Zarr arrays.

    Args:
        checkpoint_dir (str): checkpoint root directory
        get_shape_dtype_fn (str -> ((int, ...), np.dtype)): a function returning
            an array shape and dtype for a given Zarr array path
    """
    sharded_state_dict = {}
    for subdir in checkpoint_dir.iterdir():
        if not subdir.is_dir() or not (subdir / '.zarray').exists():
            continue
        key = subdir.name
        arr_shape, arr_dtype = get_shape_dtype_fn(str(subdir))

        sharded_state_dict[key] = ShardedTensor(
            key,
            None,
            numpy_to_torch_dtype_dict[arr_dtype],
            arr_shape,
            arr_shape,
            tuple(0 for _ in arr_shape),
            tuple(1 for _ in arr_shape),
        )
    return sharded_state_dict


# default_strategies[StrategyAction.LOAD_SHARDED.value][('zarr', 1)] = ZarrLoadShardedStrategy()
default_strategies[StrategyAction.SAVE_SHARDED.value][('zarr', 1)] = ZarrSaveShardedStrategy(
    'zarr', 1
)
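
A standalone check of the padding convention used by pad_to_expected_shape above: torch.nn.functional.pad takes (last-dim, ..., first-dim) pairs, hence the reversed iteration. The shapes below are arbitrary example values:

import torch
import torch.nn.functional as F

x = torch.zeros(3, 5)
expected_local_shape = (4, 5)

pad_args = []
for x_sh, exp_sh in reversed(list(zip(x.shape, expected_local_shape))):
    pad_args.extend((0, 0) if x_sh == exp_sh else (0, exp_sh - x_sh))

print(F.pad(x, pad_args).shape)  # torch.Size([4, 5])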
megatron/core/dist_checkpointing/utils.py
0 → 100644
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Helpers for manipulating sharded tensors and sharded state dicts. """
from
typing
import
Dict
,
Tuple
from
.dict_utils
import
dict_list_map_inplace
,
extract_matching_values
from
.mapping
import
(
LocalNonpersitentObject
,
ShardedBase
,
ShardedObject
,
ShardedStateDict
,
ShardedTensor
,
ShardedTensorFactory
,
StateDict
,
)
def
extract_sharded_tensors
(
sharded_state_dict
:
ShardedStateDict
,
)
->
Tuple
[
ShardedStateDict
,
StateDict
]:
""" Extract a dict consisting of only ShardedTensor objects from a given state dict with any objects.
Args:
sharded_state_dict: state dict possibly containing ShardedTensor objects
Returns:
Tuple[ShardedStateDict, StateDict]: tuple of:
- state dict with all ShardedTensor (keeping the original state dict structure)
- state dict with all objects other than ShardedTensor (keeping the original state dict structure)
"""
return
extract_matching_values
(
sharded_state_dict
,
lambda
v
:
isinstance
(
v
,
ShardedTensor
))
def
extract_sharded_tensors_and_factories
(
sharded_state_dict
:
ShardedStateDict
,
)
->
Tuple
[
ShardedStateDict
,
StateDict
]:
""" Extract a dict consisting of only ShardedTensor and ShardedTensorFactory objects from a given state dict with any objects.
Args:
sharded_state_dict: state dict possibly containing ShardedTensor and ShardedTensorFactory objects
Returns:
Tuple[ShardedStateDict, StateDict]: tuple of:
- state dict with all ShardedTensor and ShardedTensorFactory (keeping the original state dict structure)
- state dict with all other objects (keeping the original state dict structure)
"""
return
extract_matching_values
(
sharded_state_dict
,
lambda
v
:
isinstance
(
v
,
(
ShardedTensor
,
ShardedTensorFactory
))
)
def
extract_sharded_tensors_or_nonpersistent
(
sharded_state_dict
:
ShardedStateDict
,
)
->
Tuple
[
ShardedStateDict
,
StateDict
]:
""" Extract a dict consisting of only ShardedTensor, ShardedTensorFactory and LocalNonpersitentObject
objects from a given state dict with any objects.
Args:
sharded_state_dict: state dict possibly containing ShardedTensor, ShardedTensorFactory and LocalNonpersitentObject objects
Returns:
Tuple[ShardedStateDict, StateDict]: tuple of:
- state dict with all ShardedTensor, ShardedTensorFactory and LocalNonpersitentObject (keeping the original state dict structure)
- state dict with all other objects (keeping the original state dict structure)
"""
return
extract_matching_values
(
sharded_state_dict
,
lambda
v
:
isinstance
(
v
,
(
ShardedTensor
,
LocalNonpersitentObject
,
ShardedTensorFactory
)),
)
def
extract_sharded_base
(
sharded_state_dict
:
ShardedStateDict
,
)
->
Tuple
[
ShardedStateDict
,
StateDict
]:
return
extract_matching_values
(
sharded_state_dict
,
lambda
v
:
isinstance
(
v
,
ShardedBase
),)
def
extract_nonpersistent
(
sharded_state_dict
:
ShardedStateDict
,
)
->
Tuple
[
ShardedStateDict
,
StateDict
]:
return
extract_matching_values
(
sharded_state_dict
,
lambda
v
:
isinstance
(
v
,
LocalNonpersitentObject
),
)
def
add_prefix_for_sharding
(
sharded_state_dict
:
ShardedStateDict
,
prefix
:
str
):
""" Prepend a given prefix to all ShardedBase objects in a given state dict *in-place*.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict
prefix (str): prefix to be prepended
Returns:
None: state dict is modified in-place
"""
def
add_prefix
(
t
):
if
isinstance
(
t
,
ShardedBase
):
t
.
key
=
f
'
{
prefix
}{
t
.
key
}
'
return
t
dict_list_map_inplace
(
add_prefix
,
sharded_state_dict
)
def
replace_prefix_for_sharding
(
sharded_state_dict
:
ShardedStateDict
,
old_prefix
:
str
,
new_prefix
:
str
):
""" Replaces the given prefix in *all* sharded keys in a given state dict.
Errors out if some key does not begin with a given prefix.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict to replace keys in
old_prefix (str): prefix to be replaced in each key
new_prefix (str): new prefix
Returns:
None: state dict is modified in place
"""
def
_replace_prefix
(
x
):
if
isinstance
(
x
,
(
ShardedTensor
,
ShardedTensorFactory
,
ShardedObject
)):
if
not
x
.
key
.
startswith
(
old_prefix
):
raise
ValueError
(
f
'Expected
{
x
.
key
}
to begin with prefix
{
old_prefix
}
'
)
x
.
key
=
f
'
{
new_prefix
}{
x
.
key
[
len
(
old_prefix
):]
}
'
# str.removeprefix in Python >= 3.9
return
x
dict_list_map_inplace
(
_replace_prefix
,
sharded_state_dict
)
def apply_prefix_mapping(sharded_state_dict: ShardedStateDict, prefix_map: Dict[str, str]):
    """Replaces prefixes *only in keys matching* with one of prefixes in the map.

    Args:
        sharded_state_dict (ShardedStateDict): sharded state dict to replace keys in
        prefix_map (Dict[str, str]): map of old->new prefixes. The first matching prefix
            for each key is used

    Returns:
        None: state dict is modified in place
    """

    def _replace_prefixes(x):
        if not isinstance(x, (ShardedTensor, ShardedTensorFactory, ShardedObject)):
            return x
        for old_prefix, new_prefix in prefix_map.items():
            if x.key.startswith(old_prefix):
                x.key = (
                    f'{new_prefix}{x.key[len(old_prefix):]}'  # str.removeprefix in Python >= 3.9
                )
                break
        return x

    dict_list_map_inplace(_replace_prefixes, sharded_state_dict)
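
The three prefix helpers compose naturally when a submodule's sharded state dict is nested under a parent module. A minimal sketch of that pattern follows; the `submodule_sharded_sd` argument and the 'decoder.' prefixes are hypothetical placeholders, not taken from this commit:

def _example_nest_under_decoder(submodule_sharded_sd: ShardedStateDict):
    """Hypothetical illustration of the prefix helpers; not part of this commit."""
    # Prepend 'decoder.' to every ShardedBase key, in place.
    add_prefix_for_sharding(submodule_sharded_sd, 'decoder.')
    # Remap a legacy sub-prefix; only keys starting with the old prefix are touched.
    apply_prefix_mapping(submodule_sharded_sd, {'decoder.old_layer.': 'decoder.new_layer.'})
    return submodule_sharded_sd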
megatron/core/distributed/__init__.py
0 → 100644
View file @
523ec9cc
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from .distributed_data_parallel import DistributedDataParallel
from .distributed_data_parallel_config import DistributedDataParallelConfig
from .finalize_model_grads import finalize_model_grads
from .param_and_grad_buffer import ParamAndGradBuffer, shard_buffer
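
These re-exports define the public surface of `megatron.core.distributed`; downstream training code would typically import from the package rather than from the individual submodules, e.g. (a sketch, assuming the standard `megatron` package layout):

from megatron.core.distributed import (
    DistributedDataParallel,
    DistributedDataParallelConfig,
    finalize_model_grads,
)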
megatron/core/distributed/__pycache__/__init__.cpython-310.pyc
0 → 100644
View file @
523ec9cc
File added
megatron/core/distributed/__pycache__/distributed_data_parallel.cpython-310.pyc
0 → 100644
View file @
523ec9cc
File added
megatron/core/distributed/__pycache__/distributed_data_parallel_config.cpython-310.pyc
0 → 100644
View file @
523ec9cc
File added
megatron/core/distributed/__pycache__/finalize_model_grads.cpython-310.pyc
0 → 100644
View file @
523ec9cc
File added
megatron/core/distributed/__pycache__/param_and_grad_buffer.cpython-310.pyc
0 → 100644
View file @
523ec9cc
File added
megatron/core/distributed/distributed_data_parallel.py
0 → 100644
View file @
523ec9cc
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import logging
from contextlib import contextmanager
from typing import Dict, Optional

import torch

from .. import parallel_state
from ..transformer.module import MegatronModule
from ..transformer.transformer_config import TransformerConfig
from ..utils import log_single_rank
from .distributed_data_parallel_config import DistributedDataParallelConfig
from .param_and_grad_buffer import ParamAndGradBuffer

logger = logging.getLogger(__name__)


class DistributedDataParallel(MegatronModule):
    """
    DDP wrapper which stores grads in contiguous buffers. Also has the option of overlapping
    communication with backprop computation by breaking up the full model's gradients into smaller
    buckets and running all-reduce / reduce-scatter on each bucket asynchronously. This class
    also provides the option to do the gradient accumulation in a type other than the param type
    (e.g., fp32 for a bf16 model).

    Args:
        config: Transformer config object.
        ddp_config: DistributedDataParallel config object.
        module: Underlying model.
        disable_bucketing: If true, force assign all parameters to a single bucket. If false,
            use standard bucketing policy: assign parameters to smaller buckets and all-reduce
            per bucket _if_ overlap_grad_reduce is True and pp_rank is 0.
    """
    def __init__(
        self,
        config: TransformerConfig,
        ddp_config: DistributedDataParallelConfig,
        module: torch.nn.Module,
        disable_bucketing: bool = False,
    ):
        super().__init__(config=config)
        self.module = module

        # If bucket_size is not provided as an input, use sane default.
        # If using very large dp_sizes, make buckets larger to ensure that chunks used in NCCL
        # ring-reduce implementations are large enough to remain bandwidth-bound rather than
        # latency-bound.
        if ddp_config.bucket_size is None:
            ddp_config.bucket_size = max(
                40000000, 1000000 * parallel_state.get_data_parallel_world_size()
            )
        # Set bucket_size to infinity if overlap_grad_reduce is False.
        if not ddp_config.overlap_grad_reduce:
            ddp_config.bucket_size = None

        self.ddp_config = ddp_config
        log_single_rank(
            logger,
            logging.INFO,
            f'Setting up DistributedDataParallel with config {self.ddp_config}',
        )

        # Turn off bucketing if we are on a pipeline stage that is not the first (since
        # data-parallel communication on these stages is not on the critical path), or if
        # disable_bucketing is True (e.g., we might not want to break up model parameters
        # into buckets for model chunks after the first in the interleaved schedule).
        self.bucket_size = self.ddp_config.bucket_size
        if parallel_state.get_pipeline_model_parallel_rank() > 0:
            self.bucket_size = None
        if disable_bucketing:
            self.bucket_size = None

        self.module = module
        self.param_to_buffer = {}

        # Group parameters by their gradient type.
        param_to_name = {}
        dense_params = []
        expert_parallel_params = []
        for name, param in self.module.named_parameters():
            if not param.requires_grad:
                continue

            param.grad_added_to_main_grad = False
            param_to_name[param] = name

            if getattr(param, 'allreduce', True):
                dense_params.append(param)
            else:
                expert_parallel_params.append(param)

        def allocate_buffers_for_parameters(
            input_params, data_parallel_group, gradient_scaling_factor,
        ):
            param_and_grad_dtype_to_params = {}

            # Group parameters by their parameter and gradient dtypes.
            for param in input_params:
                if not param.requires_grad:
                    continue

                param_dtype = param.dtype
                grad_dtype = torch.float if self.ddp_config.grad_reduce_in_fp32 else param.dtype

                params = param_and_grad_dtype_to_params.get((param_dtype, grad_dtype), [])
                params.append(param)
                param_and_grad_dtype_to_params[(param_dtype, grad_dtype)] = params

            if not config.calculate_per_token_loss:
                target_gradient_scaling_factor = (
                    1.0 / parallel_state.get_data_parallel_world_size()
                )
                if self.ddp_config.average_in_collective:
                    # Collective is averaging gradients in collective with data_parallel_group.
                    assert (
                        gradient_scaling_factor
                        / torch.distributed.get_world_size(group=data_parallel_group)
                        == target_gradient_scaling_factor
                    )
                else:
                    assert gradient_scaling_factor == target_gradient_scaling_factor

            # Allocate the grad buffers and map the grads.
            buffers = []
            for (param_dtype, grad_dtype), params in param_and_grad_dtype_to_params.items():
                buffers.append(
                    ParamAndGradBuffer(
                        self.ddp_config,
                        param_dtype,
                        grad_dtype,
                        params,
                        data_parallel_group,
                        self.bucket_size,
                        param_to_name,
                        gradient_scaling_factor,
                    )
                )
                for param in params:
                    self.param_to_buffer[param] = buffers[-1]

            return buffers

        if config.calculate_per_token_loss:
            gradient_scaling_factor = 1.0
            expert_gradient_scaling_factor = 1.0
        else:
            if self.ddp_config.average_in_collective:
                gradient_scaling_factor = 1.0
                expert_gradient_scaling_factor = (
                    1.0 / parallel_state.get_expert_model_parallel_world_size()
                )
            else:
                data_parallel_world_size = parallel_state.get_data_parallel_world_size()
                gradient_scaling_factor = 1.0 / data_parallel_world_size
                expert_gradient_scaling_factor = 1.0 / data_parallel_world_size

        # Allocate the param+grad buffers for dense params' grads.
        self.buffers = allocate_buffers_for_parameters(
            dense_params,
            parallel_state.get_data_parallel_group(with_context_parallel=True),
            gradient_scaling_factor=gradient_scaling_factor,
        )

        # Allocate separate param+grad buffers for expert parallel params' grads.
        self.expert_parallel_buffers = allocate_buffers_for_parameters(
            expert_parallel_params,
            parallel_state.get_data_modulo_expert_parallel_group(),
            gradient_scaling_factor=expert_gradient_scaling_factor,
        )

        # Delete references to weight_tensor if they exist since we don't want two parameter copies
        # if we re-mapped parameters (which happens when we use the distributed optimizer).
        # This is a temporary workaround around a TE bug that is fixed with
        # https://github.com/NVIDIA/TransformerEngine/pull/719.
        if self.ddp_config.use_distributed_optimizer:

            @torch.no_grad()
            def unmap_weight_tensor(m):
                if hasattr(m, 'weight_tensor'):
                    m.weight_tensor = None

            self.module.apply(unmap_weight_tensor)

        # Register backward hook.
        # Accumulation functions for the gradients need to be stored so they
        # don't go out of scope.
        self.grad_accs = []
        for param in self.module.parameters():
            if param.requires_grad:
                # Expand so we get access to grad_fn.
                param_tmp = param.expand_as(param)
                # Get the gradient accumulator function.
                grad_acc = param_tmp.grad_fn.next_functions[0][0]
                grad_acc.register_hook(self._make_param_hook(param, self.param_to_buffer))
                self.grad_accs.append(grad_acc)
    def forward(self, *inputs, **kwargs):
        """
        Calls the wrapped module's forward() method.
        """
        return self.module(*inputs, **kwargs)

    def _make_param_hook(
        self,
        param: torch.nn.Parameter,
        param_to_buffer: Dict[torch.nn.Parameter, ParamAndGradBuffer],
    ):
        """
        Creates the all-reduce / reduce-scatter hook for backprop.
        """

        def param_hook(*unused):
            if param.requires_grad:
                if self.ddp_config.overlap_grad_reduce:
                    assert (
                        param.grad is not None
                    ), 'param.grad being None is not safe when overlap_grad_reduce is True'
                if param.grad is not None and (
                    not param.grad_added_to_main_grad or getattr(param, 'zero_out_wgrad', False)
                ):
                    param.main_grad.add_(param.grad.data)
                param.grad = None

                if self.ddp_config.overlap_grad_reduce:
                    param_to_buffer[param].register_grad_ready(param)

        return param_hook
    @contextmanager
    def no_sync(self):
        """
        Context manager that turns off gradient synchronization.
        """
        for buffer in self.buffers + self.expert_parallel_buffers:
            buffer.is_last_microbatch = False
        try:
            yield
        finally:
            for buffer in self.buffers + self.expert_parallel_buffers:
                buffer.is_last_microbatch = True

    def start_grad_sync(self, *unused):
        """
        Initiates grad sync (all-reduce or reduce-scatter) communication operations
        for all model gradients.

        When overlap_grad_reduce is set to True, dispatches asynchronous communication
        calls. When overlap_grad_reduce is set to False, calls synchronous
        communication ops.
        """
        for buffer in self.buffers + self.expert_parallel_buffers:
            buffer.start_grad_sync()

    def scale_gradients(self, scaling_factor: float) -> None:
        """Scale all gradients inside the buffers by `scaling_factor`."""
        for buffer in self.buffers + self.expert_parallel_buffers:
            buffer.scale_gradients(scaling_factor)

    def finish_grad_sync(self):
        """
        Finishes grad sync (all-reduce or reduce-scatter) communication operations
        for all model gradients.

        When overlap_grad_reduce is set to True, waits for asynchronous communication
        calls to complete. When overlap_grad_reduce is set to False, calls synchronous
        communication ops.
        """
        for buffer in self.buffers + self.expert_parallel_buffers:
            buffer.finish_grad_sync()
    def zero_grad_buffer(self):
        """
        Zeros out all grad buffers. Needs to be called at the beginning of each
        training iteration.
        """
        for param in self.module.parameters():
            if param.requires_grad:
                param.grad_added_to_main_grad = False
        for buffer in self.buffers + self.expert_parallel_buffers:
            buffer.reset()

    def broadcast_params(self):
        """
        Syncs parameters across all DP ranks.
        """
        for param in self.module.parameters():
            is_expert_parallel = not getattr(param, 'allreduce', True)

            if is_expert_parallel:
                data_parallel_group = parallel_state.get_data_modulo_expert_parallel_group()
            else:
                data_parallel_group = parallel_state.get_data_parallel_group(
                    with_context_parallel=True
                )
            torch.distributed.broadcast(
                param.data,
                src=torch.distributed.get_global_rank(data_parallel_group, 0),
                group=data_parallel_group,
            )
    def state_dict(self, prefix='', keep_vars=False):
        """
        Returns a dictionary containing references to the whole state of the
        wrapped module.

        Both parameters and persistent buffers (e.g. running averages) are included.
        Keys are corresponding parameter and buffer names. Parameters and buffers
        set to None are not included.
        """
        return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """
        Returns wrapped module's state_dict for checkpoint saving.
        """
        return self.module.state_dict_for_save_checkpoint(prefix=prefix, keep_vars=keep_vars)

    def load_state_dict(self, state_dict, strict=True):
        """
        Copies parameters and buffers from state_dict into the wrapped module and its
        descendants. If strict is True, then the keys of state_dict must exactly match
        the keys returned by this module's state_dict() function.
        """
        self.module.load_state_dict(state_dict, strict=strict)
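
A rough sketch of how this wrapper is meant to be driven for gradient accumulation, using the `no_sync`, `finish_grad_sync`, and `zero_grad_buffer` methods defined above; the `forward_backward` callable and the argument names are hypothetical placeholders, not part of this commit:

def _example_accumulation_step(ddp_model: DistributedDataParallel, microbatches, forward_backward):
    """Hypothetical illustration of the intended call order; not part of this commit."""
    ddp_model.zero_grad_buffer()                   # reset main_grad buffers at iteration start
    with ddp_model.no_sync():                      # suppress grad sync for all but the last microbatch
        for batch in microbatches[:-1]:
            forward_backward(ddp_model, batch)
    forward_backward(ddp_model, microbatches[-1])
    ddp_model.finish_grad_sync()                   # launch / wait for all-reduce or reduce-scatter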
megatron/core/distributed/distributed_data_parallel_config.py
0 → 100644
View file @
523ec9cc
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from dataclasses import dataclass
from typing import Optional


@dataclass
class DistributedDataParallelConfig:
    """Configuration for DistributedDataParallel."""

    grad_reduce_in_fp32: bool = False
    """If true, reduce grads in fp32."""

    overlap_grad_reduce: bool = False
    """If true, overlap grad all-reduce / reduce-scatter with backward compute."""

    use_distributed_optimizer: bool = False
    """If true, issue reduce-scatter collectives to aggregate gradients and clean up
    originally allocated model parameters, otherwise issue all-reduce collectives.
    """

    check_for_nan_in_grad: bool = False
    """If true, check for NaNs in gradients _before_ communication collective."""

    bucket_size: Optional[int] = None
    """Maximum number of parameters in each bucket. If unspecified, MCore uses a default
    value of max(40000000, 1000000 * dp_size) parameters (larger DP sizes need larger
    buckets to ensure collectives do not become latency-bound)."""

    average_in_collective: bool = False
    """If true, compute average in collective directly, as opposed to dividing by the
    dp_size first and then computing sum in the collective."""
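
As a concrete illustration of the bucket-size default documented above: with a data-parallel size of 64, max(40000000, 1000000 * 64) gives 64M parameters per bucket. A configuration enabling the overlapped, fp32-accumulated path might look like the sketch below; the chosen flag values are illustrative only, not a shipped default:

# Illustrative configuration sketch; not part of this commit.
ddp_config = DistributedDataParallelConfig(
    grad_reduce_in_fp32=True,          # accumulate / reduce grads in fp32
    overlap_grad_reduce=True,          # overlap communication with backprop
    use_distributed_optimizer=True,    # reduce-scatter instead of all-reduce
    bucket_size=None,                  # defer to max(40000000, 1000000 * dp_size)
)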
megatron/core/distributed/finalize_model_grads.py
0 → 100644
View file @
523ec9cc
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from typing import List, Optional

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from .. import parallel_state
from ..transformer.transformer_config import TransformerConfig
from ..utils import get_attr_wrapped_model, get_model_config
def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
    """
    All-reduce word embedding grads.

    Reduce grads across first and last stages to ensure that word_embeddings parameters stay in
    sync. This should only run for models that support pipelined model parallelism (BERT and GPT).
    """

    if (
        parallel_state.is_rank_in_embedding_group(ignore_virtual=True)
        and parallel_state.get_pipeline_model_parallel_world_size() > 1
    ):
        if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
            model_module = model[0]
        elif parallel_state.is_pipeline_last_stage(ignore_virtual=True):
            model_module = model[-1]
        else:
            # We do not support the interleaved schedule for T5 yet.
            model_module = model[0]

        # Look for module with 'pre_process' attribute to get around the fact that DDP and
        # other wrapper classes inherit from non-core MegatronModule that has
        # 'share_embeddings_and_output_weights' and 'shared_embedding_or_output_weight'
        # attributes already, causing get_attr_wrapped_model() to not unwrap anything here.
        # TODO: Clean this up once the wrapper classes inherit from core MegatronModule.
        model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True)
        if model_module.share_embeddings_and_output_weights:
            weight = model_module.shared_embedding_or_output_weight()
            grad = weight.main_grad
            torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())
def _allreduce_position_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
    """
    All-reduce position_embeddings grad across first (encoder) and split (decoder) stages to
    ensure that position embeddings parameters stay in sync. This should only run for T5 models
    with pipeline parallelism.
    """
    if (
        parallel_state.is_rank_in_position_embedding_group()
        and parallel_state.get_pipeline_model_parallel_world_size() > 1
        and config.pipeline_model_parallel_split_rank is not None
    ):
        model_module = model[0]
        grad = get_attr_wrapped_model(
            model_module, 'language_model.embedding.position_embeddings.weight.main_grad'
        )
        torch.distributed.all_reduce(grad, group=parallel_state.get_position_embedding_group())
def _allreduce_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
    """
    All-reduce both word and position embeddings.
    """
    _allreduce_word_embedding_grads(model, config)
    _allreduce_position_embedding_grads(model, config)
def _allreduce_layernorm_grads(model: List[torch.nn.Module], config: TransformerConfig):
    """
    All-reduce layernorm grads (for sequence parallelism).
    """

    # All-reduce layernorm parameters across model parallel nodes
    # when sequence parallelism is used
    if parallel_state.get_tensor_model_parallel_world_size() > 1 and (
        config.sequence_parallel or config.qk_layernorm
    ):
        grads = []
        for model_chunk in model:
            for name, param in get_attr_wrapped_model(model_chunk, 'named_parameters')():
                if (
                    param.requires_grad
                    and getattr(param, 'sequence_parallel', False)
                    or 'q_layernorm' in name
                    or 'k_layernorm' in name
                ):
                    grad = param.main_grad
                    grads.append(grad.data)
        if grads:
            coalesced = _flatten_dense_tensors(grads)
            torch.distributed.all_reduce(
                coalesced, group=parallel_state.get_tensor_model_parallel_group()
            )
            for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                buf.copy_(synced)
def finalize_model_grads(model: List[torch.nn.Module], num_tokens: Optional[torch.Tensor] = None):
    """
    All-reduce all model grads across DP replicas, layernorm grads for sequence parallelism,
    embedding grads across first and last pipeline stages (if not tied),
    scale gradients by `num_tokens`.
    """

    config = get_model_config(model[0])

    # All-reduce / reduce-scatter across DP replicas.
    if config.timers is not None:
        config.timers('all-grads-sync', log_level=1).start(barrier=config.barrier_with_L1_time)
    for model_chunk in model:
        model_chunk.finish_grad_sync()
    if config.timers is not None:
        config.timers('all-grads-sync').stop()

    # All-reduce layer-norm grads (for sequence parallelism).
    if config.timers is not None:
        config.timers('layernorm-grads-all-reduce', log_level=1).start(
            barrier=config.barrier_with_L1_time
        )
    _allreduce_layernorm_grads(model, config)
    if config.timers is not None:
        config.timers('layernorm-grads-all-reduce').stop()

    # All-reduce embedding grads (for pipeline parallelism).
    if config.timers is not None:
        config.timers('embedding-grads-all-reduce', log_level=1).start(
            barrier=config.barrier_with_L1_time
        )
    _allreduce_embedding_grads(model, config)
    if config.timers is not None:
        config.timers('embedding-grads-all-reduce').stop()

    # Normalize gradients for per-token loss normalization.
    # If we are normalizing by the number of tokens, then we use that as the divisor. This number
    # will be the total number of non-padded tokens in the global batch.
    if num_tokens is not None:
        # The number of tokens is only present on the last stage, so broadcast it
        # to the other ranks in the pipeline parallel group.
        torch.distributed.broadcast(
            num_tokens,
            src=parallel_state.get_pipeline_model_parallel_last_rank(),
            group=parallel_state.get_pipeline_model_parallel_group(),
        )
        # All-reduce across DP ranks.
        torch.distributed.all_reduce(num_tokens, group=parallel_state.get_data_parallel_group())
        for model_chunk in model:
            if num_tokens > 0:
                scaling = 1.0 / num_tokens
                model_chunk.scale_gradients(scaling)
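
In a training loop this function runs once per global step, after all model chunks have finished their backward passes and before the optimizer step. A hypothetical sketch of that ordering is shown below; `model_chunks`, `optimizer`, and `total_num_tokens` are placeholders, and the chunks are assumed to be wrapped in the DistributedDataParallel class from this commit:

def _example_end_of_step(model_chunks, optimizer, total_num_tokens=None):
    """Hypothetical illustration of the intended call order; not part of this commit."""
    # Sync grads across DP replicas, layernorm/embedding groups, then optionally scale by token count.
    finalize_model_grads(model_chunks, num_tokens=total_num_tokens)
    optimizer.step()
    for chunk in model_chunks:
        chunk.zero_grad_buffer()  # reset grad buffers before the next iteration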